Merge branch 'usememos:main' into feature/button-icon-divider

This commit is contained in:
gkmzfk6um 2026-02-02 09:32:42 +01:00 committed by GitHub
commit 183fdb3c57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
222 changed files with 6610 additions and 2610 deletions

View File

@ -1 +1,13 @@
web/node_modules
web/dist
.git
.github
build/
tmp/
memos
*.md
.gitignore
.golangci.yaml
.dockerignore
docs/
.DS_Store

View File

@ -11,37 +11,79 @@ on:
- "go.sum"
- "**.go"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
go-static-checks:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: 1.25
check-latest: true
cache: true
cache-dependency-path: go.sum
- name: Verify go.mod is tidy
run: |
go mod tidy -go=1.25
git diff --exit-code
- name: golangci-lint
uses: golangci/golangci-lint-action@v8
uses: golangci/golangci-lint-action@v9
with:
version: v2.4.0
args: --verbose --timeout=3m
skip-cache: true
args: --timeout=3m
go-tests:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
test-group:
- store
- server
- plugin
- other
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version: 1.25
check-latest: true
cache: true
- name: Run all tests
run: go test -v ./... | tee test.log; exit ${PIPESTATUS[0]}
- name: Pretty print tests running time
run: grep --color=never -e '--- PASS:' -e '--- FAIL:' test.log | sed 's/[:()]//g' | awk '{print $2,$3,$4}' | sort -t' ' -nk3 -r | awk '{sum += $3; print $1,$2,$3,sum"s"}'
cache-dependency-path: go.sum
- name: Run tests - ${{ matrix.test-group }}
run: |
case "${{ matrix.test-group }}" in
store)
# Run store tests for all drivers (sqlite, mysql, postgres)
# The TestMain in store/test runs all drivers when DRIVER is not set
# Note: We run without -race for container tests due to testcontainers race issues
go test -v -coverprofile=coverage.out -covermode=atomic ./store/...
;;
server)
go test -v -race -coverprofile=coverage.out -covermode=atomic ./server/...
;;
plugin)
go test -v -race -coverprofile=coverage.out -covermode=atomic ./plugin/...
;;
other)
go test -v -race -coverprofile=coverage.out -covermode=atomic \
./cmd/... ./internal/... ./proto/...
;;
esac
env:
DRIVER: ${{ matrix.test-group == 'store' && '' || 'sqlite' }}
- name: Upload coverage
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
uses: codecov/codecov-action@v5
with:
files: ./coverage.out
flags: ${{ matrix.test-group }}
fail_ci_if_error: false

View File

@ -4,36 +4,128 @@ on:
push:
branches: [main]
env:
DOCKER_PLATFORMS: |
linux/amd64
linux/arm64
concurrency:
group: ${{ github.workflow }}-${{ github.repository }}
cancel-in-progress: true
jobs:
build-and-push-canary-image:
build-frontend:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: pnpm/action-setup@v4.2.0
with:
version: 10
- uses: actions/setup-node@v6
with:
node-version: "22"
cache: pnpm
cache-dependency-path: "web/pnpm-lock.yaml"
- name: Get pnpm store directory
id: pnpm-cache
shell: bash
run: echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT
- name: Setup pnpm cache
uses: actions/cache@v5
with:
path: ${{ steps.pnpm-cache.outputs.STORE_PATH }}
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('web/pnpm-lock.yaml') }}
restore-keys: ${{ runner.os }}-pnpm-store-
- run: pnpm install --frozen-lockfile
working-directory: web
- name: Run frontend build
run: pnpm release
working-directory: web
- name: Upload frontend artifacts
uses: actions/upload-artifact@v6
with:
name: frontend-dist
path: server/router/frontend/dist
retention-days: 1
build-push:
needs: build-frontend
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
steps:
- uses: actions/checkout@v6
- name: Download frontend artifacts
uses: actions/download-artifact@v7
with:
name: frontend-dist
path: server/router/frontend/dist
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_TOKEN }}
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ github.token }}
- name: Build and push by digest
id: build
uses: docker/build-push-action@v6
with:
context: .
file: ./scripts/Dockerfile
platforms: ${{ matrix.platform }}
cache-from: type=gha,scope=build-${{ matrix.platform }}
cache-to: type=gha,mode=max,scope=build-${{ matrix.platform }}
outputs: type=image,name=neosmemo/memos,push-by-digest=true,name-canonical=true,push=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v6
with:
name: digests-${{ strategy.job-index }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
merge:
needs: build-push
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- uses: actions/checkout@v5
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Download digests
uses: actions/download-artifact@v7
with:
platforms: ${{ env.DOCKER_PLATFORMS }}
pattern: digests-*
merge-multiple: true
path: /tmp/digests
- name: Set up Docker Buildx
id: buildx
uses: docker/setup-buildx-action@v3
with:
version: latest
install: true
platforms: ${{ env.DOCKER_PLATFORMS }}
- name: Docker meta
id: meta
@ -60,32 +152,15 @@ jobs:
username: ${{ github.actor }}
password: ${{ github.token }}
# Frontend build.
- uses: pnpm/action-setup@v4.1.0
with:
version: 10
- uses: actions/setup-node@v5
with:
node-version: "22"
cache: pnpm
cache-dependency-path: "web/pnpm-lock.yaml"
- run: pnpm install
working-directory: web
- name: Run frontend build
run: pnpm release
working-directory: web
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf 'neosmemo/memos@sha256:%s ' *)
env:
DOCKER_METADATA_OUTPUT_JSON: ${{ steps.meta.outputs.json }}
- name: Build and Push
id: docker_build
uses: docker/build-push-action@v6
with:
context: .
file: ./scripts/Dockerfile
platforms: ${{ env.DOCKER_PLATFORMS }}
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
build-args: |
BUILDKIT_INLINE_CACHE=1
- name: Inspect images
run: |
docker buildx imagetools inspect neosmemo/memos:canary
docker buildx imagetools inspect ghcr.io/usememos/memos:canary

View File

@ -7,41 +7,84 @@ on:
tags:
- "v*.*.*"
env:
DOCKER_PLATFORMS: |
linux/amd64
linux/arm/v7
linux/arm64
jobs:
build-and-push-image:
prepare:
runs-on: ubuntu-latest
outputs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Extract version
id: version
run: |
if [[ "$GITHUB_REF_TYPE" == "tag" ]]; then
echo "version=${GITHUB_REF_NAME#v}" >> $GITHUB_OUTPUT
else
echo "version=${GITHUB_REF_NAME#release/}" >> $GITHUB_OUTPUT
fi
build-frontend:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: pnpm/action-setup@v4.2.0
with:
version: 10
- uses: actions/setup-node@v6
with:
node-version: "22"
cache: pnpm
cache-dependency-path: "web/pnpm-lock.yaml"
- name: Get pnpm store directory
id: pnpm-cache
shell: bash
run: echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT
- name: Setup pnpm cache
uses: actions/cache@v5
with:
path: ${{ steps.pnpm-cache.outputs.STORE_PATH }}
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('web/pnpm-lock.yaml') }}
restore-keys: ${{ runner.os }}-pnpm-store-
- run: pnpm install --frozen-lockfile
working-directory: web
- name: Run frontend build
run: pnpm release
working-directory: web
- name: Upload frontend artifacts
uses: actions/upload-artifact@v6
with:
name: frontend-dist
path: server/router/frontend/dist
retention-days: 1
build-push:
needs: [prepare, build-frontend]
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm/v7
- linux/arm64
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- name: Download frontend artifacts
uses: actions/download-artifact@v7
with:
name: frontend-dist
path: server/router/frontend/dist
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
with:
platforms: ${{ env.DOCKER_PLATFORMS }}
- name: Set up Docker Buildx
id: buildx
uses: docker/setup-buildx-action@v3
with:
version: latest
install: true
platforms: ${{ env.DOCKER_PLATFORMS }}
- name: Extract version
run: |
if [[ "$GITHUB_REF_TYPE" == "tag" ]]; then
echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
else
echo "VERSION=${GITHUB_REF_NAME#release/}" >> $GITHUB_ENV
fi
- name: Login to Docker Hub
uses: docker/login-action@v3
@ -56,6 +99,48 @@ jobs:
username: ${{ github.actor }}
password: ${{ github.token }}
- name: Build and push by digest
id: build
uses: docker/build-push-action@v6
with:
context: .
file: ./scripts/Dockerfile
platforms: ${{ matrix.platform }}
cache-from: type=gha,scope=build-${{ matrix.platform }}
cache-to: type=gha,mode=max,scope=build-${{ matrix.platform }}
outputs: type=image,name=neosmemo/memos,push-by-digest=true,name-canonical=true,push=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v6
with:
name: digests-${{ strategy.job-index }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
merge:
needs: [prepare, build-push]
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Download digests
uses: actions/download-artifact@v7
with:
pattern: digests-*
merge-multiple: true
path: /tmp/digests
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@ -64,40 +149,36 @@ jobs:
neosmemo/memos
ghcr.io/usememos/memos
tags: |
type=semver,pattern={{version}},value=${{ env.VERSION }}
type=semver,pattern={{major}}.{{minor}},value=${{ env.VERSION }}
type=semver,pattern={{version}},value=${{ needs.prepare.outputs.version }}
type=semver,pattern={{major}}.{{minor}},value=${{ needs.prepare.outputs.version }}
type=raw,value=stable
flavor: |
latest=false
labels: |
org.opencontainers.image.version=${{ env.VERSION }}
org.opencontainers.image.version=${{ needs.prepare.outputs.version }}
# Frontend build.
- uses: pnpm/action-setup@v4.1.0
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
version: 10
- uses: actions/setup-node@v5
with:
node-version: "22"
cache: pnpm
cache-dependency-path: "web/pnpm-lock.yaml"
- run: pnpm install
working-directory: web
- name: Run frontend build
run: pnpm release
working-directory: web
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_TOKEN }}
- name: Build and Push
id: docker_build
uses: docker/build-push-action@v6
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
context: .
file: ./scripts/Dockerfile
platforms: ${{ env.DOCKER_PLATFORMS }}
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
build-args: |
BUILDKIT_INLINE_CACHE=1
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ github.token }}
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf 'neosmemo/memos@sha256:%s ' *)
env:
DOCKER_METADATA_OUTPUT_JSON: ${{ steps.meta.outputs.json }}
- name: Inspect images
run: |
docker buildx imagetools inspect neosmemo/memos:stable
docker buildx imagetools inspect ghcr.io/usememos/memos:stable

View File

@ -13,11 +13,11 @@ jobs:
static-checks:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: pnpm/action-setup@v4.1.0
- uses: actions/checkout@v6
- uses: pnpm/action-setup@v4.2.0
with:
version: 9
- uses: actions/setup-node@v5
- uses: actions/setup-node@v6
with:
node-version: "20"
cache: pnpm
@ -31,11 +31,11 @@ jobs:
frontend-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: pnpm/action-setup@v4.1.0
- uses: actions/checkout@v6
- uses: pnpm/action-setup@v4.2.0
with:
version: 9
- uses: actions/setup-node@v5
- uses: actions/setup-node@v6
with:
node-version: "20"
cache: pnpm

View File

@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v5
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Setup buf

View File

@ -11,7 +11,7 @@ jobs:
issues: write
steps:
- uses: actions/stale@v10.0.0
- uses: actions/stale@v10.1.1
with:
days-before-issue-stale: 14
days-before-issue-close: 7

View File

@ -165,7 +165,7 @@ type Driver interface {
4. Demo mode: Seed with demo data
**Schema Versioning:**
- Stored in `instance_setting` table (key: `bb.general.version`)
- Stored in `system_setting` table
- Format: `major.minor.patch`
- Migration files: `store/migration/{driver}/{version}/NN__description.sql`
- See: `store/migrator.go:21-414`
@ -191,7 +191,7 @@ cd proto && buf generate
```bash
# Start dev server
go run ./cmd/memos --mode dev --port 8081
go run ./cmd/memos --port 8081
# Run all tests
go test ./...
@ -458,7 +458,7 @@ cd web && pnpm lint
| Variable | Default | Description |
|----------|----------|-------------|
| `MEMOS_MODE` | `dev` | Mode: `dev`, `prod`, `demo` |
| `MEMOS_DEMO` | `false` | Enable demo mode |
| `MEMOS_PORT` | `8081` | HTTP port |
| `MEMOS_ADDR` | `` | Bind address (empty = all) |
| `MEMOS_DATA` | `~/.memos` | Data directory |
@ -503,13 +503,6 @@ cd web && pnpm lint
## Common Tasks
### Debugging Database Issues
1. Check connection string in logs
2. Verify `store/db/{driver}/migration/` files exist
3. Check schema version: `SELECT * FROM instance_setting WHERE key = 'bb.general.version'`
4. Test migration: `go test ./store/test/... -v`
### Debugging API Issues
1. Check Connect interceptor logs: `server/router/api/v1/connect_interceptors.go:79-105`
@ -571,7 +564,7 @@ Each plugin has its own README with usage examples.
## Security Notes
- JWT secrets must be kept secret (`MEMOS_MODE=prod` generates random secret)
- JWT secrets must be kept secret (generated on first run in production mode)
- Personal Access Tokens stored as SHA-256 hashes in database
- CSRF protection via SameSite cookies
- CORS enabled for all origins (configure for production)

View File

@ -22,10 +22,10 @@ An open-source, self-hosted note-taking service. Your thoughts, your data, your
---
[**LambdaTest** - Cross-browser testing cloud](https://www.lambdatest.com/?utm_source=memos&utm_medium=sponsor)
[**TestMu AI** - The worlds first full-stack Agentic AI Quality Engineering platform](https://www.testmu.ai/?utm_source=memos&utm_medium=sponsor)
<a href="https://www.lambdatest.com/?utm_source=memos&utm_medium=sponsor" target="_blank" rel="noopener">
<img src="https://www.lambdatest.com/blue-logo.png" alt="LambdaTest - Cross-browser testing cloud" height="50" />
<a href="https://www.testmu.ai/?utm_source=memos&utm_medium=sponsor" target="_blank" rel="noopener">
<img src="https://usememos.com/sponsors/testmu.svg" alt="TestMu AI" height="36" />
</a>
## Overview

View File

@ -25,7 +25,7 @@ var (
Short: `An open source, lightweight note-taking service. Easily capture and share your great thoughts.`,
Run: func(_ *cobra.Command, _ []string) {
instanceProfile := &profile.Profile{
Mode: viper.GetString("mode"),
Demo: viper.GetBool("demo"),
Addr: viper.GetString("addr"),
Port: viper.GetInt("port"),
UNIXSock: viper.GetString("unix-sock"),
@ -33,10 +33,12 @@ var (
Driver: viper.GetString("driver"),
DSN: viper.GetString("dsn"),
InstanceURL: viper.GetString("instance-url"),
Version: version.GetCurrentVersion(viper.GetString("mode")),
}
instanceProfile.Version = version.GetCurrentVersion()
if err := instanceProfile.Validate(); err != nil {
panic(err)
slog.Error("failed to validate profile", "error", err)
return
}
ctx, cancel := context.WithCancel(context.Background())
@ -71,6 +73,7 @@ var (
if err != http.ErrServerClosed {
slog.Error("failed to start server", "error", err)
cancel()
return
}
}
@ -89,11 +92,11 @@ var (
)
func init() {
viper.SetDefault("mode", "dev")
viper.SetDefault("demo", false)
viper.SetDefault("driver", "sqlite")
viper.SetDefault("port", 8081)
rootCmd.PersistentFlags().String("mode", "dev", `mode of server, can be "prod" or "dev" or "demo"`)
rootCmd.PersistentFlags().Bool("demo", false, "enable demo mode")
rootCmd.PersistentFlags().String("addr", "", "address of server")
rootCmd.PersistentFlags().Int("port", 8081, "port of server")
rootCmd.PersistentFlags().String("unix-sock", "", "path to the unix socket, overrides --addr and --port")
@ -102,7 +105,7 @@ func init() {
rootCmd.PersistentFlags().String("dsn", "", "database source name(aka. DSN)")
rootCmd.PersistentFlags().String("instance-url", "", "the url of your memos instance")
if err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode")); err != nil {
if err := viper.BindPFlag("demo", rootCmd.PersistentFlags().Lookup("demo")); err != nil {
panic(err)
}
if err := viper.BindPFlag("addr", rootCmd.PersistentFlags().Lookup("addr")); err != nil {
@ -129,15 +132,12 @@ func init() {
viper.SetEnvPrefix("memos")
viper.AutomaticEnv()
if err := viper.BindEnv("instance-url", "MEMOS_INSTANCE_URL"); err != nil {
panic(err)
}
}
func printGreetings(profile *profile.Profile) {
fmt.Printf("Memos %s started successfully!\n", profile.Version)
if profile.IsDev() {
if profile.Demo {
fmt.Fprint(os.Stderr, "Development mode is enabled\n")
if profile.DSN != "" {
fmt.Fprintf(os.Stderr, "Database: %s\n", profile.DSN)
@ -147,7 +147,6 @@ func printGreetings(profile *profile.Profile) {
// Server information
fmt.Printf("Data directory: %s\n", profile.Data)
fmt.Printf("Database driver: %s\n", profile.Driver)
fmt.Printf("Mode: %s\n", profile.Mode)
// Connection information
if len(profile.UNIXSock) == 0 {
@ -170,6 +169,6 @@ func printGreetings(profile *profile.Profile) {
func main() {
if err := rootCmd.Execute(); err != nil {
panic(err)
os.Exit(1)
}
}

2
go.mod
View File

@ -9,6 +9,7 @@ require (
github.com/aws/aws-sdk-go-v2/credentials v1.18.16
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.19.4
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
github.com/docker/docker v28.5.1+incompatible
github.com/go-sql-driver/mysql v1.9.3
github.com/google/cel-go v0.26.1
github.com/google/uuid v1.6.0
@ -50,7 +51,6 @@ require (
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect

View File

@ -13,8 +13,8 @@ import (
// Profile is the configuration to start main server.
type Profile struct {
// Mode can be "prod" or "dev" or "demo"
Mode string
// Demo indicates if the server is in demo mode
Demo bool
// Addr is the binding address for server
Addr string
// Port is the binding port for server
@ -34,15 +34,12 @@ type Profile struct {
InstanceURL string
}
func (p *Profile) IsDev() bool {
return p.Mode != "prod"
}
func checkDataDir(dataDir string) (string, error) {
// Convert to absolute path if relative path is supplied.
if !filepath.IsAbs(dataDir) {
relativeDir := filepath.Join(filepath.Dir(os.Args[0]), dataDir)
absDir, err := filepath.Abs(relativeDir)
// Use current working directory, not the binary's directory
// This ensures we use the actual working directory where the process runs
absDir, err := filepath.Abs(dataDir)
if err != nil {
return "", err
}
@ -58,21 +55,35 @@ func checkDataDir(dataDir string) (string, error) {
}
func (p *Profile) Validate() error {
if p.Mode != "demo" && p.Mode != "dev" && p.Mode != "prod" {
p.Mode = "demo"
}
if p.Mode == "prod" && p.Data == "" {
// Set default data directory if not specified
if p.Data == "" {
if runtime.GOOS == "windows" {
p.Data = filepath.Join(os.Getenv("ProgramData"), "memos")
if _, err := os.Stat(p.Data); os.IsNotExist(err) {
if err := os.MkdirAll(p.Data, 0770); err != nil {
slog.Error("failed to create data directory", slog.String("data", p.Data), slog.String("error", err.Error()))
return err
}
}
} else {
p.Data = "/var/opt/memos"
// On Linux/macOS, check if /var/opt/memos exists and is writable (Docker scenario)
if info, err := os.Stat("/var/opt/memos"); err == nil && info.IsDir() {
// Check if we can write to this directory
testFile := filepath.Join("/var/opt/memos", ".write-test")
if err := os.WriteFile(testFile, []byte("test"), 0600); err == nil {
os.Remove(testFile)
p.Data = "/var/opt/memos"
} else {
// /var/opt/memos exists but is not writable, use current directory
slog.Warn("/var/opt/memos is not writable, using current directory")
p.Data = "."
}
} else {
// /var/opt/memos doesn't exist, use current directory (local development)
p.Data = "."
}
}
}
// Create data directory if it doesn't exist
if _, err := os.Stat(p.Data); os.IsNotExist(err) {
if err := os.MkdirAll(p.Data, 0770); err != nil {
slog.Error("failed to create data directory", slog.String("data", p.Data), slog.String("error", err.Error()))
return err
}
}
@ -84,7 +95,11 @@ func (p *Profile) Validate() error {
p.Data = dataDir
if p.Driver == "sqlite" && p.DSN == "" {
dbFile := fmt.Sprintf("memos_%s.db", p.Mode)
mode := "prod"
if p.Demo {
mode = "demo"
}
dbFile := fmt.Sprintf("memos_%s.db", mode)
p.DSN = filepath.Join(dataDir, dbFile)
}

View File

@ -9,15 +9,9 @@ import (
// Version is the service current released version.
// Semantic versioning: https://semver.org/
var Version = "0.25.3"
var Version = "0.26.0"
// DevVersion is the service current development version.
var DevVersion = "0.25.3"
func GetCurrentVersion(mode string) string {
if mode == "dev" || mode == "demo" {
return DevVersion
}
func GetCurrentVersion() string {
return Version
}

View File

@ -1,11 +1,11 @@
package email
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
)
func TestSend(t *testing.T) {
@ -106,34 +106,22 @@ func TestSendAsyncConcurrent(t *testing.T) {
FromEmail: "test@example.com",
}
// Send multiple emails concurrently
var wg sync.WaitGroup
g := errgroup.Group{}
count := 5
for i := 0; i < count; i++ {
wg.Add(1)
go func() {
defer wg.Done()
g.Go(func() error {
message := &Message{
To: []string{"recipient@example.com"},
Subject: "Concurrent Test",
Body: "Test body",
}
SendAsync(config, message)
}()
return nil
})
}
// Should complete without deadlock
done := make(chan bool)
go func() {
wg.Wait()
done <- true
}()
select {
case <-done:
// Success
case <-time.After(1 * time.Second):
t.Fatal("SendAsync calls did not complete in time")
if err := g.Wait(); err != nil {
t.Fatalf("SendAsync calls failed: %v", err)
}
}

View File

@ -114,3 +114,46 @@ type FunctionValue struct {
}
func (*FunctionValue) isValueExpr() {}
// ListComprehensionCondition represents CEL macros like exists(), all(), filter().
type ListComprehensionCondition struct {
Kind ComprehensionKind
Field string // The list field to iterate over (e.g., "tags")
IterVar string // The iteration variable name (e.g., "t")
Predicate PredicateExpr // The predicate to evaluate for each element
}
func (*ListComprehensionCondition) isCondition() {}
// ComprehensionKind enumerates the types of list comprehensions.
type ComprehensionKind string
const (
ComprehensionExists ComprehensionKind = "exists"
)
// PredicateExpr represents predicates used in comprehensions.
type PredicateExpr interface {
isPredicateExpr()
}
// StartsWithPredicate represents t.startsWith("prefix").
type StartsWithPredicate struct {
Prefix string
}
func (*StartsWithPredicate) isPredicateExpr() {}
// EndsWithPredicate represents t.endsWith("suffix").
type EndsWithPredicate struct {
Suffix string
}
func (*EndsWithPredicate) isPredicateExpr() {}
// ContainsPredicate represents t.contains("substring").
type ContainsPredicate struct {
Substring string
}
func (*ContainsPredicate) isPredicateExpr() {}

View File

@ -36,6 +36,8 @@ func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
return nil, errors.Errorf("identifier %q is not boolean", name)
}
return &FieldPredicateCondition{Field: name}, nil
case *exprv1.Expr_ComprehensionExpr:
return buildComprehensionCondition(v.ComprehensionExpr, schema)
default:
return nil, errors.New("unsupported top-level expression")
}
@ -415,3 +417,170 @@ func evaluateNumeric(expr *exprv1.Expr) (int64, bool, error) {
func timeNowUnix() int64 {
return time.Now().Unix()
}
// buildComprehensionCondition handles CEL comprehension expressions (exists, all, etc.).
func buildComprehensionCondition(comp *exprv1.Expr_Comprehension, schema Schema) (Condition, error) {
// Determine the comprehension kind by examining the loop initialization and step
kind, err := detectComprehensionKind(comp)
if err != nil {
return nil, err
}
// Get the field being iterated over
iterRangeIdent := comp.IterRange.GetIdentExpr()
if iterRangeIdent == nil {
return nil, errors.New("comprehension range must be a field identifier")
}
fieldName := iterRangeIdent.GetName()
// Validate the field
field, ok := schema.Field(fieldName)
if !ok {
return nil, errors.Errorf("unknown field %q in comprehension", fieldName)
}
if field.Kind != FieldKindJSONList {
return nil, errors.Errorf("field %q does not support comprehension (must be a list)", fieldName)
}
// Extract the predicate from the loop step
predicate, err := extractPredicate(comp, schema)
if err != nil {
return nil, err
}
return &ListComprehensionCondition{
Kind: kind,
Field: fieldName,
IterVar: comp.IterVar,
Predicate: predicate,
}, nil
}
// detectComprehensionKind determines if this is an exists() macro.
// Only exists() is currently supported.
func detectComprehensionKind(comp *exprv1.Expr_Comprehension) (ComprehensionKind, error) {
// Check the accumulator initialization
accuInit := comp.AccuInit.GetConstExpr()
if accuInit == nil {
return "", errors.New("comprehension accumulator must be initialized with a constant")
}
// exists() starts with false and uses OR (||) in loop step
if !accuInit.GetBoolValue() {
if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_||_" {
return ComprehensionExists, nil
}
}
// all() starts with true and uses AND (&&) - not supported
if accuInit.GetBoolValue() {
if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_&&_" {
return "", errors.New("all() comprehension is not supported; use exists() instead")
}
}
return "", errors.New("unsupported comprehension type; only exists() is supported")
}
// extractPredicate extracts the predicate expression from the comprehension loop step.
func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr, error) {
// The loop step is: @result || predicate(t) for exists
// or: @result && predicate(t) for all
step := comp.LoopStep.GetCallExpr()
if step == nil {
return nil, errors.New("comprehension loop step must be a call expression")
}
if len(step.Args) != 2 {
return nil, errors.New("comprehension loop step must have two arguments")
}
// The predicate is the second argument
predicateExpr := step.Args[1]
predicateCall := predicateExpr.GetCallExpr()
if predicateCall == nil {
return nil, errors.New("comprehension predicate must be a function call")
}
// Handle different predicate functions
switch predicateCall.Function {
case "startsWith":
return buildStartsWithPredicate(predicateCall, comp.IterVar)
case "endsWith":
return buildEndsWithPredicate(predicateCall, comp.IterVar)
case "contains":
return buildContainsPredicate(predicateCall, comp.IterVar)
default:
return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function)
}
}
// buildStartsWithPredicate extracts the pattern from t.startsWith("prefix").
func buildStartsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
// Verify the target is the iteration variable
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
return nil, errors.Errorf("startsWith target must be the iteration variable %q", iterVar)
}
if len(call.Args) != 1 {
return nil, errors.New("startsWith expects exactly one argument")
}
prefix, err := getConstValue(call.Args[0])
if err != nil {
return nil, errors.Wrap(err, "startsWith argument must be a constant string")
}
prefixStr, ok := prefix.(string)
if !ok {
return nil, errors.New("startsWith argument must be a string")
}
return &StartsWithPredicate{Prefix: prefixStr}, nil
}
// buildEndsWithPredicate extracts the pattern from t.endsWith("suffix").
func buildEndsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
return nil, errors.Errorf("endsWith target must be the iteration variable %q", iterVar)
}
if len(call.Args) != 1 {
return nil, errors.New("endsWith expects exactly one argument")
}
suffix, err := getConstValue(call.Args[0])
if err != nil {
return nil, errors.Wrap(err, "endsWith argument must be a constant string")
}
suffixStr, ok := suffix.(string)
if !ok {
return nil, errors.New("endsWith argument must be a string")
}
return &EndsWithPredicate{Suffix: suffixStr}, nil
}
// buildContainsPredicate extracts the pattern from t.contains("substring").
func buildContainsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
return nil, errors.Errorf("contains target must be the iteration variable %q", iterVar)
}
if len(call.Args) != 1 {
return nil, errors.New("contains expects exactly one argument")
}
substring, err := getConstValue(call.Args[0])
if err != nil {
return nil, errors.Wrap(err, "contains argument must be a constant string")
}
substringStr, ok := substring.(string)
if !ok {
return nil, errors.New("contains argument must be a string")
}
return &ContainsPredicate{Substring: substringStr}, nil
}

View File

@ -74,6 +74,8 @@ func (r *renderer) renderCondition(cond Condition) (renderResult, error) {
return r.renderElementInCondition(c)
case *ContainsCondition:
return r.renderContainsCondition(c)
case *ListComprehensionCondition:
return r.renderListComprehension(c)
case *ConstantCondition:
if c.Value {
return renderResult{trivial: true}, nil
@ -461,13 +463,108 @@ func (r *renderer) renderContainsCondition(cond *ContainsCondition) (renderResul
}
}
func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (renderResult, error) {
field, ok := r.schema.Field(cond.Field)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
}
if field.Kind != FieldKindJSONList {
return renderResult{}, errors.Errorf("field %q is not a JSON list", cond.Field)
}
// Render based on predicate type
switch pred := cond.Predicate.(type) {
case *StartsWithPredicate:
return r.renderTagStartsWith(field, pred.Prefix, cond.Kind)
case *EndsWithPredicate:
return r.renderTagEndsWith(field, pred.Suffix, cond.Kind)
case *ContainsPredicate:
return r.renderTagContains(field, pred.Substring, cond.Kind)
default:
return renderResult{}, errors.Errorf("unsupported predicate type %T in comprehension", pred)
}
}
// renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")).
func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) {
arrayExpr := jsonArrayExpr(r.dialect, field)
switch r.dialect {
case DialectSQLite, DialectMySQL:
// Match exact tag or tags with this prefix (hierarchical support)
exactMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s"%%`, prefix))
prefixMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s%%`, prefix))
condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil
case DialectPostgres:
// Use PostgreSQL's powerful JSON operators
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(fmt.Sprintf(`"%s"`, prefix)))
prefixMatch := fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(fmt.Sprintf(`%%"%s%%`, prefix)))
condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil
default:
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
}
}
// renderTagEndsWith generates SQL for tags.exists(t, t.endsWith("suffix")).
func (r *renderer) renderTagEndsWith(field Field, suffix string, _ ComprehensionKind) (renderResult, error) {
arrayExpr := jsonArrayExpr(r.dialect, field)
pattern := fmt.Sprintf(`%%%s"%%`, suffix)
likeExpr := r.buildJSONArrayLike(arrayExpr, pattern)
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil
}
// renderTagContains generates SQL for tags.exists(t, t.contains("substring")).
func (r *renderer) renderTagContains(field Field, substring string, _ ComprehensionKind) (renderResult, error) {
arrayExpr := jsonArrayExpr(r.dialect, field)
pattern := fmt.Sprintf(`%%%s%%`, substring)
likeExpr := r.buildJSONArrayLike(arrayExpr, pattern)
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil
}
// buildJSONArrayLike builds a LIKE expression for matching within a JSON array.
// Returns the LIKE clause without NULL/empty checks.
func (r *renderer) buildJSONArrayLike(arrayExpr, pattern string) string {
switch r.dialect {
case DialectSQLite, DialectMySQL:
return fmt.Sprintf("%s LIKE %s", arrayExpr, r.addArg(pattern))
case DialectPostgres:
return fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(pattern))
default:
return ""
}
}
// wrapWithNullCheck wraps a condition with NULL and empty array checks.
// This ensures we don't match against NULL or empty JSON arrays.
func (r *renderer) wrapWithNullCheck(arrayExpr, condition string) string {
var nullCheck string
switch r.dialect {
case DialectSQLite:
nullCheck = fmt.Sprintf("%s IS NOT NULL AND %s != '[]'", arrayExpr, arrayExpr)
case DialectMySQL:
nullCheck = fmt.Sprintf("%s IS NOT NULL AND JSON_LENGTH(%s) > 0", arrayExpr, arrayExpr)
case DialectPostgres:
nullCheck = fmt.Sprintf("%s IS NOT NULL AND jsonb_array_length(%s) > 0", arrayExpr, arrayExpr)
default:
return condition
}
return fmt.Sprintf("(%s AND %s)", condition, nullCheck)
}
func (r *renderer) jsonBoolPredicate(field Field) (string, error) {
expr := jsonExtractExpr(r.dialect, field)
switch r.dialect {
case DialectSQLite:
return fmt.Sprintf("%s IS TRUE", expr), nil
case DialectMySQL:
return fmt.Sprintf("%s = CAST('true' AS JSON)", expr), nil
return fmt.Sprintf("COALESCE(%s, CAST('false' AS JSON)) = CAST('true' AS JSON)", expr), nil
case DialectPostgres:
return fmt.Sprintf("(%s)::boolean IS TRUE", expr), nil
default:

View File

@ -256,7 +256,7 @@ func NewAttachmentSchema() Schema {
Name: "filename",
Kind: FieldKindScalar,
Type: FieldTypeString,
Column: Column{Table: "resource", Name: "filename"},
Column: Column{Table: "attachment", Name: "filename"},
SupportsContains: true,
Expressions: map[DialectName]string{},
},
@ -264,14 +264,14 @@ func NewAttachmentSchema() Schema {
Name: "mime_type",
Kind: FieldKindScalar,
Type: FieldTypeString,
Column: Column{Table: "resource", Name: "type"},
Column: Column{Table: "attachment", Name: "type"},
Expressions: map[DialectName]string{},
},
"create_time": {
Name: "create_time",
Kind: FieldKindScalar,
Type: FieldTypeTimestamp,
Column: Column{Table: "resource", Name: "created_ts"},
Column: Column{Table: "attachment", Name: "created_ts"},
Expressions: map[DialectName]string{
// MySQL stores created_ts as TIMESTAMP, needs conversion to epoch
DialectMySQL: "UNIX_TIMESTAMP(%s)",
@ -284,7 +284,7 @@ func NewAttachmentSchema() Schema {
Name: "memo_id",
Kind: FieldKindScalar,
Type: FieldTypeInt,
Column: Column{Table: "resource", Name: "memo_id"},
Column: Column{Table: "attachment", Name: "memo_id"},
Expressions: map[DialectName]string{},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,

View File

@ -2,6 +2,7 @@ syntax = "proto3";
package memos.api.v1;
import "api/v1/user_service.proto";
import "google/api/annotations.proto";
import "google/api/client.proto";
import "google/api/field_behavior.proto";
@ -34,18 +35,18 @@ service InstanceService {
// Instance profile message containing basic instance information.
message InstanceProfile {
// The name of instance owner.
// Format: users/{user}
string owner = 1;
// Version is the current version of instance.
string version = 2;
// Mode is the instance mode (e.g. "prod", "dev" or "demo").
string mode = 3;
// Demo indicates if the instance is in demo mode.
bool demo = 3;
// Instance URL is the URL of the instance.
string instance_url = 6;
// The first administrator who set up this instance.
// When null, instance requires initial setup (creating the first admin account).
User admin = 7;
}
// Request for instance profile.

View File

@ -173,11 +173,13 @@ message Memo {
(google.api.resource_reference) = {type: "memos.api.v1/User"}
];
// Output only. The creation timestamp.
google.protobuf.Timestamp create_time = 4 [(google.api.field_behavior) = OUTPUT_ONLY];
// The creation timestamp.
// If not set on creation, the server will set it to the current time.
google.protobuf.Timestamp create_time = 4 [(google.api.field_behavior) = OPTIONAL];
// Output only. The last update timestamp.
google.protobuf.Timestamp update_time = 5 [(google.api.field_behavior) = OUTPUT_ONLY];
// The last update timestamp.
// If not set on creation, the server will set it to the current time.
google.protobuf.Timestamp update_time = 5 [(google.api.field_behavior) = OPTIONAL];
// The display timestamp of the memo.
google.protobuf.Timestamp display_time = 6 [(google.api.field_behavior) = OPTIONAL];

View File

@ -203,13 +203,10 @@ message User {
// User role enumeration.
enum Role {
// Unspecified role.
ROLE_UNSPECIFIED = 0;
// Host role with full system access.
HOST = 1;
// Admin role with administrative privileges.
// Admin role with system access.
ADMIN = 2;
// Regular user role.
// User role with limited access.
USER = 3;
}
}

View File

@ -138,15 +138,15 @@ func (InstanceSetting_StorageSetting_StorageType) EnumDescriptor() ([]byte, []in
// Instance profile message containing basic instance information.
type InstanceProfile struct {
state protoimpl.MessageState `protogen:"open.v1"`
// The name of instance owner.
// Format: users/{user}
Owner string `protobuf:"bytes,1,opt,name=owner,proto3" json:"owner,omitempty"`
// Version is the current version of instance.
Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"`
// Mode is the instance mode (e.g. "prod", "dev" or "demo").
Mode string `protobuf:"bytes,3,opt,name=mode,proto3" json:"mode,omitempty"`
// Demo indicates if the instance is in demo mode.
Demo bool `protobuf:"varint,3,opt,name=demo,proto3" json:"demo,omitempty"`
// Instance URL is the URL of the instance.
InstanceUrl string `protobuf:"bytes,6,opt,name=instance_url,json=instanceUrl,proto3" json:"instance_url,omitempty"`
InstanceUrl string `protobuf:"bytes,6,opt,name=instance_url,json=instanceUrl,proto3" json:"instance_url,omitempty"`
// The first administrator who set up this instance.
// When null, instance requires initial setup (creating the first admin account).
Admin *User `protobuf:"bytes,7,opt,name=admin,proto3" json:"admin,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@ -181,13 +181,6 @@ func (*InstanceProfile) Descriptor() ([]byte, []int) {
return file_api_v1_instance_service_proto_rawDescGZIP(), []int{0}
}
func (x *InstanceProfile) GetOwner() string {
if x != nil {
return x.Owner
}
return ""
}
func (x *InstanceProfile) GetVersion() string {
if x != nil {
return x.Version
@ -195,11 +188,11 @@ func (x *InstanceProfile) GetVersion() string {
return ""
}
func (x *InstanceProfile) GetMode() string {
func (x *InstanceProfile) GetDemo() bool {
if x != nil {
return x.Mode
return x.Demo
}
return ""
return false
}
func (x *InstanceProfile) GetInstanceUrl() string {
@ -209,6 +202,13 @@ func (x *InstanceProfile) GetInstanceUrl() string {
return ""
}
func (x *InstanceProfile) GetAdmin() *User {
if x != nil {
return x.Admin
}
return nil
}
// Request for instance profile.
type GetInstanceProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
@ -875,12 +875,12 @@ var File_api_v1_instance_service_proto protoreflect.FileDescriptor
const file_api_v1_instance_service_proto_rawDesc = "" +
"\n" +
"\x1dapi/v1/instance_service.proto\x12\fmemos.api.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a google/protobuf/field_mask.proto\"x\n" +
"\x0fInstanceProfile\x12\x14\n" +
"\x05owner\x18\x01 \x01(\tR\x05owner\x12\x18\n" +
"\x1dapi/v1/instance_service.proto\x12\fmemos.api.v1\x1a\x19api/v1/user_service.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a google/protobuf/field_mask.proto\"\x8c\x01\n" +
"\x0fInstanceProfile\x12\x18\n" +
"\aversion\x18\x02 \x01(\tR\aversion\x12\x12\n" +
"\x04mode\x18\x03 \x01(\tR\x04mode\x12!\n" +
"\finstance_url\x18\x06 \x01(\tR\vinstanceUrl\"\x1b\n" +
"\x04demo\x18\x03 \x01(\bR\x04demo\x12!\n" +
"\finstance_url\x18\x06 \x01(\tR\vinstanceUrl\x12(\n" +
"\x05admin\x18\a \x01(\v2\x12.memos.api.v1.UserR\x05admin\"\x1b\n" +
"\x19GetInstanceProfileRequest\"\x99\x0f\n" +
"\x0fInstanceSetting\x12\x17\n" +
"\x04name\x18\x01 \x01(\tB\x03\xe0A\bR\x04name\x12W\n" +
@ -970,28 +970,30 @@ var file_api_v1_instance_service_proto_goTypes = []any{
(*InstanceSetting_MemoRelatedSetting)(nil), // 9: memos.api.v1.InstanceSetting.MemoRelatedSetting
(*InstanceSetting_GeneralSetting_CustomProfile)(nil), // 10: memos.api.v1.InstanceSetting.GeneralSetting.CustomProfile
(*InstanceSetting_StorageSetting_S3Config)(nil), // 11: memos.api.v1.InstanceSetting.StorageSetting.S3Config
(*fieldmaskpb.FieldMask)(nil), // 12: google.protobuf.FieldMask
(*User)(nil), // 12: memos.api.v1.User
(*fieldmaskpb.FieldMask)(nil), // 13: google.protobuf.FieldMask
}
var file_api_v1_instance_service_proto_depIdxs = []int32{
7, // 0: memos.api.v1.InstanceSetting.general_setting:type_name -> memos.api.v1.InstanceSetting.GeneralSetting
8, // 1: memos.api.v1.InstanceSetting.storage_setting:type_name -> memos.api.v1.InstanceSetting.StorageSetting
9, // 2: memos.api.v1.InstanceSetting.memo_related_setting:type_name -> memos.api.v1.InstanceSetting.MemoRelatedSetting
4, // 3: memos.api.v1.UpdateInstanceSettingRequest.setting:type_name -> memos.api.v1.InstanceSetting
12, // 4: memos.api.v1.UpdateInstanceSettingRequest.update_mask:type_name -> google.protobuf.FieldMask
10, // 5: memos.api.v1.InstanceSetting.GeneralSetting.custom_profile:type_name -> memos.api.v1.InstanceSetting.GeneralSetting.CustomProfile
1, // 6: memos.api.v1.InstanceSetting.StorageSetting.storage_type:type_name -> memos.api.v1.InstanceSetting.StorageSetting.StorageType
11, // 7: memos.api.v1.InstanceSetting.StorageSetting.s3_config:type_name -> memos.api.v1.InstanceSetting.StorageSetting.S3Config
3, // 8: memos.api.v1.InstanceService.GetInstanceProfile:input_type -> memos.api.v1.GetInstanceProfileRequest
5, // 9: memos.api.v1.InstanceService.GetInstanceSetting:input_type -> memos.api.v1.GetInstanceSettingRequest
6, // 10: memos.api.v1.InstanceService.UpdateInstanceSetting:input_type -> memos.api.v1.UpdateInstanceSettingRequest
2, // 11: memos.api.v1.InstanceService.GetInstanceProfile:output_type -> memos.api.v1.InstanceProfile
4, // 12: memos.api.v1.InstanceService.GetInstanceSetting:output_type -> memos.api.v1.InstanceSetting
4, // 13: memos.api.v1.InstanceService.UpdateInstanceSetting:output_type -> memos.api.v1.InstanceSetting
11, // [11:14] is the sub-list for method output_type
8, // [8:11] is the sub-list for method input_type
8, // [8:8] is the sub-list for extension type_name
8, // [8:8] is the sub-list for extension extendee
0, // [0:8] is the sub-list for field type_name
12, // 0: memos.api.v1.InstanceProfile.admin:type_name -> memos.api.v1.User
7, // 1: memos.api.v1.InstanceSetting.general_setting:type_name -> memos.api.v1.InstanceSetting.GeneralSetting
8, // 2: memos.api.v1.InstanceSetting.storage_setting:type_name -> memos.api.v1.InstanceSetting.StorageSetting
9, // 3: memos.api.v1.InstanceSetting.memo_related_setting:type_name -> memos.api.v1.InstanceSetting.MemoRelatedSetting
4, // 4: memos.api.v1.UpdateInstanceSettingRequest.setting:type_name -> memos.api.v1.InstanceSetting
13, // 5: memos.api.v1.UpdateInstanceSettingRequest.update_mask:type_name -> google.protobuf.FieldMask
10, // 6: memos.api.v1.InstanceSetting.GeneralSetting.custom_profile:type_name -> memos.api.v1.InstanceSetting.GeneralSetting.CustomProfile
1, // 7: memos.api.v1.InstanceSetting.StorageSetting.storage_type:type_name -> memos.api.v1.InstanceSetting.StorageSetting.StorageType
11, // 8: memos.api.v1.InstanceSetting.StorageSetting.s3_config:type_name -> memos.api.v1.InstanceSetting.StorageSetting.S3Config
3, // 9: memos.api.v1.InstanceService.GetInstanceProfile:input_type -> memos.api.v1.GetInstanceProfileRequest
5, // 10: memos.api.v1.InstanceService.GetInstanceSetting:input_type -> memos.api.v1.GetInstanceSettingRequest
6, // 11: memos.api.v1.InstanceService.UpdateInstanceSetting:input_type -> memos.api.v1.UpdateInstanceSettingRequest
2, // 12: memos.api.v1.InstanceService.GetInstanceProfile:output_type -> memos.api.v1.InstanceProfile
4, // 13: memos.api.v1.InstanceService.GetInstanceSetting:output_type -> memos.api.v1.InstanceSetting
4, // 14: memos.api.v1.InstanceService.UpdateInstanceSetting:output_type -> memos.api.v1.InstanceSetting
12, // [12:15] is the sub-list for method output_type
9, // [9:12] is the sub-list for method input_type
9, // [9:9] is the sub-list for extension type_name
9, // [9:9] is the sub-list for extension extendee
0, // [0:9] is the sub-list for field type_name
}
func init() { file_api_v1_instance_service_proto_init() }
@ -999,6 +1001,7 @@ func file_api_v1_instance_service_proto_init() {
if File_api_v1_instance_service_proto != nil {
return
}
file_api_v1_user_service_proto_init()
file_api_v1_instance_service_proto_msgTypes[2].OneofWrappers = []any{
(*InstanceSetting_GeneralSetting_)(nil),
(*InstanceSetting_StorageSetting_)(nil),

View File

@ -222,9 +222,11 @@ type Memo struct {
// The name of the creator.
// Format: users/{user}
Creator string `protobuf:"bytes,3,opt,name=creator,proto3" json:"creator,omitempty"`
// Output only. The creation timestamp.
// The creation timestamp.
// If not set on creation, the server will set it to the current time.
CreateTime *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=create_time,json=createTime,proto3" json:"create_time,omitempty"`
// Output only. The last update timestamp.
// The last update timestamp.
// If not set on creation, the server will set it to the current time.
UpdateTime *timestamppb.Timestamp `protobuf:"bytes,5,opt,name=update_time,json=updateTime,proto3" json:"update_time,omitempty"`
// The display timestamp of the memo.
DisplayTime *timestamppb.Timestamp `protobuf:"bytes,6,opt,name=display_time,json=displayTime,proto3" json:"display_time,omitempty"`
@ -1816,9 +1818,9 @@ const file_api_v1_memo_service_proto_rawDesc = "" +
"\x05state\x18\x02 \x01(\x0e2\x13.memos.api.v1.StateB\x03\xe0A\x02R\x05state\x123\n" +
"\acreator\x18\x03 \x01(\tB\x19\xe0A\x03\xfaA\x13\n" +
"\x11memos.api.v1/UserR\acreator\x12@\n" +
"\vcreate_time\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\n" +
"\vcreate_time\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x01R\n" +
"createTime\x12@\n" +
"\vupdate_time\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\n" +
"\vupdate_time\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x01R\n" +
"updateTime\x12B\n" +
"\fdisplay_time\x18\x06 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x01R\vdisplayTime\x12\x1d\n" +
"\acontent\x18\a \x01(\tB\x03\xe0A\x02R\acontent\x12=\n" +

View File

@ -29,13 +29,10 @@ const (
type User_Role int32
const (
// Unspecified role.
User_ROLE_UNSPECIFIED User_Role = 0
// Host role with full system access.
User_HOST User_Role = 1
// Admin role with administrative privileges.
// Admin role with system access.
User_ADMIN User_Role = 2
// Regular user role.
// User role with limited access.
User_USER User_Role = 3
)
@ -43,13 +40,11 @@ const (
var (
User_Role_name = map[int32]string{
0: "ROLE_UNSPECIFIED",
1: "HOST",
2: "ADMIN",
3: "USER",
}
User_Role_value = map[string]int32{
"ROLE_UNSPECIFIED": 0,
"HOST": 1,
"ADMIN": 2,
"USER": 3,
}
@ -2509,7 +2504,7 @@ var File_api_v1_user_service_proto protoreflect.FileDescriptor
const file_api_v1_user_service_proto_rawDesc = "" +
"\n" +
"\x19api/v1/user_service.proto\x12\fmemos.api.v1\x1a\x13api/v1/common.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a google/protobuf/field_mask.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xcb\x04\n" +
"\x19api/v1/user_service.proto\x12\fmemos.api.v1\x1a\x13api/v1/common.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a google/protobuf/field_mask.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xc1\x04\n" +
"\x04User\x12\x17\n" +
"\x04name\x18\x01 \x01(\tB\x03\xe0A\bR\x04name\x120\n" +
"\x04role\x18\x02 \x01(\x0e2\x17.memos.api.v1.User.RoleB\x03\xe0A\x02R\x04role\x12\x1f\n" +
@ -2525,10 +2520,9 @@ const file_api_v1_user_service_proto_rawDesc = "" +
" \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\n" +
"createTime\x12@\n" +
"\vupdate_time\x18\v \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\n" +
"updateTime\";\n" +
"updateTime\"1\n" +
"\x04Role\x12\x14\n" +
"\x10ROLE_UNSPECIFIED\x10\x00\x12\b\n" +
"\x04HOST\x10\x01\x12\t\n" +
"\x10ROLE_UNSPECIFIED\x10\x00\x12\t\n" +
"\x05ADMIN\x10\x02\x12\b\n" +
"\x04USER\x10\x03:7\xeaA4\n" +
"\x11memos.api.v1/User\x12\fusers/{user}\x1a\x04name*\x05users2\x04user\"\x9d\x01\n" +

View File

@ -2132,20 +2132,21 @@ components:
InstanceProfile:
type: object
properties:
owner:
type: string
description: |-
The name of instance owner.
Format: users/{user}
version:
type: string
description: Version is the current version of instance.
mode:
type: string
description: Mode is the instance mode (e.g. "prod", "dev" or "demo").
demo:
type: boolean
description: Demo indicates if the instance is in demo mode.
instanceUrl:
type: string
description: Instance URL is the URL of the instance.
admin:
allOf:
- $ref: '#/components/schemas/User'
description: |-
The first administrator who set up this instance.
When null, instance requires initial setup (creating the first admin account).
description: Instance profile message containing basic instance information.
InstanceSetting:
type: object
@ -2470,14 +2471,16 @@ components:
The name of the creator.
Format: users/{user}
createTime:
readOnly: true
type: string
description: Output only. The creation timestamp.
description: |-
The creation timestamp.
If not set on creation, the server will set it to the current time.
format: date-time
updateTime:
readOnly: true
type: string
description: Output only. The last update timestamp.
description: |-
The last update timestamp.
If not set on creation, the server will set it to the current time.
format: date-time
displayTime:
type: string
@ -2857,7 +2860,6 @@ components:
role:
enum:
- ROLE_UNSPECIFIED
- HOST
- ADMIN
- USER
type: string

View File

@ -1,30 +1,56 @@
FROM golang:1.25-alpine AS backend
FROM --platform=$BUILDPLATFORM golang:1.25-alpine AS backend
WORKDIR /backend-build
# Install build dependencies
RUN apk add --no-cache git ca-certificates
# Copy go mod files and download dependencies (cached layer)
COPY go.mod go.sum ./
RUN go mod download
RUN --mount=type=cache,target=/go/pkg/mod \
go mod download
# Copy source code (use .dockerignore to exclude unnecessary files)
COPY . .
# Please build frontend first, so that the static files are available.
# Refer to `pnpm release` in package.json for the build command.
ARG TARGETOS TARGETARCH VERSION COMMIT
RUN --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \
go build -ldflags="-s -w" -o memos ./cmd/memos
CGO_ENABLED=0 GOOS=$TARGETOS GOARCH=$TARGETARCH \
go build \
-trimpath \
-ldflags="-s -w -extldflags '-static'" \
-tags netgo,osusergo \
-o memos \
./cmd/memos
FROM alpine:latest AS monolithic
WORKDIR /usr/local/memos
# Use minimal Alpine with security updates
FROM alpine:3.21 AS monolithic
RUN apk add --no-cache tzdata
ENV TZ="UTC"
# Install runtime dependencies and create non-root user in single layer
RUN apk add --no-cache tzdata ca-certificates su-exec && \
addgroup -g 10001 -S nonroot && \
adduser -u 10001 -S -G nonroot -h /var/opt/memos nonroot && \
mkdir -p /var/opt/memos /usr/local/memos && \
chown -R nonroot:nonroot /var/opt/memos
COPY --from=backend /backend-build/memos /usr/local/memos/
COPY ./scripts/entrypoint.sh /usr/local/memos/
# Copy binary and entrypoint to /usr/local/memos
COPY --from=backend /backend-build/memos /usr/local/memos/memos
COPY --from=backend --chmod=755 /backend-build/scripts/entrypoint.sh /usr/local/memos/entrypoint.sh
# Run as root to fix permissions, entrypoint will drop to nonroot
USER root
# Set working directory to the writable volume
WORKDIR /var/opt/memos
# Data directory
VOLUME /var/opt/memos
ENV TZ="UTC" \
MEMOS_PORT="5230"
EXPOSE 5230
# Directory to store the data, which can be referenced as the mounting point.
RUN mkdir -p /var/opt/memos
VOLUME /var/opt/memos
ENV MEMOS_MODE="prod"
ENV MEMOS_PORT="5230"
ENTRYPOINT ["./entrypoint.sh", "./memos"]
ENTRYPOINT ["/usr/local/memos/entrypoint.sh", "/usr/local/memos/memos"]

View File

@ -29,4 +29,4 @@ go build -o "$OUTPUT" ./cmd/memos
echo "Build successful!"
echo "To run the application, execute the following command:"
echo "$OUTPUT --mode dev"
echo "$OUTPUT"

View File

@ -1,6 +1,6 @@
services:
memos:
image: neosmemo/memos:latest
image: neosmemo/memos:stable
container_name: memos
volumes:
- ~/.memos/:/var/opt/memos

View File

@ -1,5 +1,19 @@
#!/usr/bin/env sh
# Fix ownership of data directory for users upgrading from older versions
# where files were created as root
MEMOS_UID=${MEMOS_UID:-10001}
MEMOS_GID=${MEMOS_GID:-10001}
DATA_DIR="/var/opt/memos"
if [ "$(id -u)" = "0" ]; then
# Running as root, fix permissions and drop to nonroot
if [ -d "$DATA_DIR" ]; then
chown -R "$MEMOS_UID:$MEMOS_GID" "$DATA_DIR" 2>/dev/null || true
fi
exec su-exec "$MEMOS_UID:$MEMOS_GID" "$0" "$@"
fi
file_env() {
var="$1"
fileVar="${var}_FILE"

View File

@ -74,7 +74,7 @@ func TestParseAccessTokenV2(t *testing.T) {
})
t.Run("parses token with different roles", func(t *testing.T) {
roles := []string{"USER", "ADMIN", "HOST"}
roles := []string{"USER", "ADMIN"}
for _, role := range roles {
token, _, err := GenerateAccessTokenV2(1, "testuser", role, "ACTIVE", secret)
require.NoError(t, err)

View File

@ -29,8 +29,9 @@ var PublicMethods = map[string]struct{}{
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": {},
// Memo Service - public memos (visibility filtering done in service layer)
"/memos.api.v1.MemoService/GetMemo": {},
"/memos.api.v1.MemoService/ListMemos": {},
"/memos.api.v1.MemoService/GetMemo": {},
"/memos.api.v1.MemoService/ListMemos": {},
"/memos.api.v1.MemoService/ListMemoComments": {},
}
// IsPublicMethod checks if a procedure path is public (no authentication required).

View File

@ -0,0 +1,191 @@
package v1
import (
"bytes"
"image"
"image/color"
"image/jpeg"
"testing"
"github.com/disintegration/imaging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestShouldStripExif(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mimeType string
expected bool
}{
{
name: "JPEG should strip EXIF",
mimeType: "image/jpeg",
expected: true,
},
{
name: "JPG should strip EXIF",
mimeType: "image/jpg",
expected: true,
},
{
name: "TIFF should strip EXIF",
mimeType: "image/tiff",
expected: true,
},
{
name: "WebP should strip EXIF",
mimeType: "image/webp",
expected: true,
},
{
name: "HEIC should strip EXIF",
mimeType: "image/heic",
expected: true,
},
{
name: "HEIF should strip EXIF",
mimeType: "image/heif",
expected: true,
},
{
name: "PNG should not strip EXIF",
mimeType: "image/png",
expected: false,
},
{
name: "GIF should not strip EXIF",
mimeType: "image/gif",
expected: false,
},
{
name: "text file should not strip EXIF",
mimeType: "text/plain",
expected: false,
},
{
name: "PDF should not strip EXIF",
mimeType: "application/pdf",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := shouldStripExif(tt.mimeType)
assert.Equal(t, tt.expected, result)
})
}
}
func TestStripImageExif(t *testing.T) {
t.Parallel()
// Create a simple test image
img := image.NewRGBA(image.Rect(0, 0, 100, 100))
// Fill with red color
for y := 0; y < 100; y++ {
for x := 0; x < 100; x++ {
img.Set(x, y, color.RGBA{R: 255, G: 0, B: 0, A: 255})
}
}
// Encode as JPEG
var buf bytes.Buffer
err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 90})
require.NoError(t, err)
originalData := buf.Bytes()
t.Run("strip JPEG metadata", func(t *testing.T) {
t.Parallel()
strippedData, err := stripImageExif(originalData, "image/jpeg")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's still a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.Equal(t, 100, decodedImg.Bounds().Dx())
assert.Equal(t, 100, decodedImg.Bounds().Dy())
})
t.Run("strip JPG metadata (alternate extension)", func(t *testing.T) {
t.Parallel()
strippedData, err := stripImageExif(originalData, "image/jpg")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's still a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.NotNil(t, decodedImg)
})
t.Run("strip PNG metadata", func(t *testing.T) {
t.Parallel()
// Encode as PNG first
var pngBuf bytes.Buffer
err := imaging.Encode(&pngBuf, img, imaging.PNG)
require.NoError(t, err)
strippedData, err := stripImageExif(pngBuf.Bytes(), "image/png")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's still a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.Equal(t, 100, decodedImg.Bounds().Dx())
assert.Equal(t, 100, decodedImg.Bounds().Dy())
})
t.Run("handle WebP format by converting to JPEG", func(t *testing.T) {
t.Parallel()
// WebP format will be converted to JPEG
strippedData, err := stripImageExif(originalData, "image/webp")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.NotNil(t, decodedImg)
})
t.Run("handle HEIC format by converting to JPEG", func(t *testing.T) {
t.Parallel()
strippedData, err := stripImageExif(originalData, "image/heic")
require.NoError(t, err)
assert.NotEmpty(t, strippedData)
// Verify it's a valid image
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
require.NoError(t, err)
assert.NotNil(t, decodedImg)
})
t.Run("return error for invalid image data", func(t *testing.T) {
t.Parallel()
invalidData := []byte("not an image")
_, err := stripImageExif(invalidData, "image/jpeg")
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode image")
})
t.Run("return error for empty image data", func(t *testing.T) {
t.Parallel()
emptyData := []byte{}
_, err := stripImageExif(emptyData, "image/jpeg")
assert.Error(t, err)
})
}

View File

@ -6,6 +6,7 @@ import (
"encoding/binary"
"fmt"
"io"
"log/slog"
"mime"
"net/http"
"os"
@ -14,6 +15,7 @@ import (
"strings"
"time"
"github.com/disintegration/imaging"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
@ -38,6 +40,10 @@ const (
MebiByte = 1024 * 1024
// ThumbnailCacheFolder is the folder name where the thumbnail images are stored.
ThumbnailCacheFolder = ".thumbnail_cache"
// defaultJPEGQuality is the JPEG quality used when re-encoding images for EXIF stripping.
// Quality 95 maintains visual quality while ensuring metadata is removed.
defaultJPEGQuality = 95
)
var SupportedThumbnailMimeTypes = []string{
@ -45,6 +51,17 @@ var SupportedThumbnailMimeTypes = []string{
"image/jpeg",
}
// exifCapableImageTypes defines image formats that may contain EXIF metadata.
// These formats will have their EXIF metadata stripped on upload for privacy.
var exifCapableImageTypes = map[string]bool{
"image/jpeg": true,
"image/jpg": true,
"image/tiff": true,
"image/webp": true,
"image/heic": true,
"image/heif": true,
}
func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.CreateAttachmentRequest) (*v1pb.Attachment, error) {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
@ -111,6 +128,21 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat
create.Size = int64(size)
create.Blob = request.Attachment.Content
// Strip EXIF metadata from images for privacy protection.
// This removes sensitive information like GPS location, device details, etc.
if shouldStripExif(create.Type) {
if strippedBlob, err := stripImageExif(create.Blob, create.Type); err != nil {
// Log warning but continue with original image to ensure uploads don't fail.
slog.Warn("failed to strip EXIF metadata from image",
slog.String("type", create.Type),
slog.String("filename", create.Filename),
slog.String("error", err.Error()))
} else {
create.Blob = strippedBlob
create.Size = int64(len(strippedBlob))
}
}
if err := SaveAttachmentBlob(ctx, s.Profile, s.Store, create); err != nil {
return nil, status.Errorf(codes.Internal, "failed to save attachment blob: %v", err)
}
@ -214,6 +246,12 @@ func (s *APIV1Service) GetAttachment(ctx context.Context, request *v1pb.GetAttac
if attachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found")
}
// Check access permission based on linked memo visibility.
if err := s.checkAttachmentAccess(ctx, attachment); err != nil {
return nil, err
}
return convertAttachmentFromStore(attachment), nil
}
@ -225,10 +263,24 @@ func (s *APIV1Service) UpdateAttachment(ctx context.Context, request *v1pb.Updat
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
}
if attachment == nil {
return nil, status.Errorf(codes.NotFound, "attachment not found")
}
// Only the creator or admin can update the attachment.
if attachment.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
currentTs := time.Now().Unix()
update := &store.UpdateAttachment{
@ -516,3 +568,89 @@ func (s *APIV1Service) validateAttachmentFilter(ctx context.Context, filterStr s
}
return nil
}
// checkAttachmentAccess verifies the user has permission to access the attachment.
// For unlinked attachments (no memo), only the creator can access.
// For linked attachments, access follows the memo's visibility rules.
func (s *APIV1Service) checkAttachmentAccess(ctx context.Context, attachment *store.Attachment) error {
user, _ := s.fetchCurrentUser(ctx)
// For unlinked attachments, only the creator can access.
if attachment.MemoID == nil {
if user == nil {
return status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if attachment.CreatorID != user.ID && !isSuperUser(user) {
return status.Errorf(codes.PermissionDenied, "permission denied")
}
return nil
}
// For linked attachments, check memo visibility.
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
if err != nil {
return status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return status.Errorf(codes.NotFound, "memo not found")
}
if memo.Visibility == store.Public {
return nil
}
if user == nil {
return status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
return status.Errorf(codes.PermissionDenied, "permission denied")
}
return nil
}
// shouldStripExif checks if the MIME type is an image format that may contain EXIF metadata.
// Returns true for formats like JPEG, TIFF, WebP, HEIC, and HEIF which commonly contain
// privacy-sensitive metadata such as GPS coordinates, camera settings, and device information.
func shouldStripExif(mimeType string) bool {
return exifCapableImageTypes[mimeType]
}
// stripImageExif removes EXIF metadata from image files by decoding and re-encoding them.
// This prevents exposure of sensitive metadata such as GPS location, camera details, and timestamps.
//
// The function preserves the correct image orientation by applying EXIF orientation tags
// during decoding before stripping all metadata. Images are re-encoded with high quality
// to minimize visual degradation.
//
// Supported formats:
// - JPEG/JPG: Re-encoded as JPEG with quality 95
// - PNG: Re-encoded as PNG (lossless)
// - TIFF/WebP/HEIC/HEIF: Re-encoded as JPEG with quality 95
//
// Returns the cleaned image data without any EXIF metadata, or an error if processing fails.
func stripImageExif(imageData []byte, mimeType string) ([]byte, error) {
// Decode image with automatic EXIF orientation correction.
// This ensures the image displays correctly after metadata removal.
img, err := imaging.Decode(bytes.NewReader(imageData), imaging.AutoOrientation(true))
if err != nil {
return nil, errors.Wrap(err, "failed to decode image")
}
// Re-encode the image without EXIF metadata.
var buf bytes.Buffer
var encodeErr error
if mimeType == "image/png" {
// Preserve PNG format for lossless encoding
encodeErr = imaging.Encode(&buf, img, imaging.PNG)
} else {
// For JPEG, TIFF, WebP, HEIC, HEIF - re-encode as JPEG.
// This ensures EXIF is stripped and provides good compression.
encodeErr = imaging.Encode(&buf, img, imaging.JPEG, imaging.JPEGQuality(defaultJPEGQuality))
}
if encodeErr != nil {
return nil, errors.Wrap(encodeErr, "failed to encode image")
}
return buf.Bytes(), nil
}

View File

@ -64,5 +64,5 @@ func unmarshalPageToken(s string, pageToken *v1pb.PageToken) error {
}
func isSuperUser(user *store.User) bool {
return user.Role == store.RoleAdmin || user.Role == store.RoleHost
return user.Role == store.RoleAdmin
}

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log/slog"
"reflect"
"runtime/debug"
"connectrpc.com/connect"
@ -50,10 +51,30 @@ func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc
// Set metadata in context so services can use metadata.FromIncomingContext()
ctx = metadata.NewIncomingContext(ctx, md)
return next(ctx, req)
// Execute the request
resp, err := next(ctx, req)
// Prevent browser caching of API responses to avoid stale data issues
// See: https://github.com/usememos/memos/issues/5470
if !isNilAnyResponse(resp) && resp.Header() != nil {
resp.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
resp.Header().Set("Pragma", "no-cache")
resp.Header().Set("Expires", "0")
}
return resp, err
}
}
func isNilAnyResponse(resp connect.AnyResponse) bool {
if resp == nil {
return true
}
val := reflect.ValueOf(resp)
return val.Kind() == reflect.Ptr && val.IsNil()
}
func (*MetadataInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return next
}

View File

@ -18,7 +18,10 @@ func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil || currentUser.Role != store.RoleHost {
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
@ -84,7 +87,10 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil || currentUser.Role != store.RoleHost {
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
@ -125,7 +131,10 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil || currentUser.Role != store.RoleHost {
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
@ -219,7 +228,7 @@ func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProv
}
func redactIdentityProviderResponse(identityProvider *v1pb.IdentityProvider, userRole store.Role) *v1pb.IdentityProvider {
if userRole != store.RoleHost {
if userRole != store.RoleAdmin {
if identityProvider.Type == v1pb.IdentityProvider_OAUTH2 {
identityProvider.Config.GetOauth2Config().ClientSecret = ""
}

View File

@ -15,17 +15,16 @@ import (
// GetInstanceProfile returns the instance profile.
func (s *APIV1Service) GetInstanceProfile(ctx context.Context, _ *v1pb.GetInstanceProfileRequest) (*v1pb.InstanceProfile, error) {
admin, err := s.GetInstanceAdmin(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance admin: %v", err)
}
instanceProfile := &v1pb.InstanceProfile{
Version: s.Profile.Version,
Mode: s.Profile.Mode,
Demo: s.Profile.Demo,
InstanceUrl: s.Profile.InstanceURL,
}
owner, err := s.GetInstanceOwner(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance owner: %v", err)
}
if owner != nil {
instanceProfile.Owner = owner.Name
Admin: admin, // nil when not initialized
}
return instanceProfile, nil
}
@ -64,13 +63,16 @@ func (s *APIV1Service) GetInstanceSetting(ctx context.Context, request *v1pb.Get
return nil, status.Errorf(codes.NotFound, "instance setting not found")
}
// For storage setting, only host can get it.
// For storage setting, only admin can get it.
if instanceSetting.Key == storepb.InstanceSettingKey_STORAGE {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil || user.Role != store.RoleHost {
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if user.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
@ -86,7 +88,7 @@ func (s *APIV1Service) UpdateInstanceSetting(ctx context.Context, request *v1pb.
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if user.Role != store.RoleHost {
if user.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
@ -267,24 +269,17 @@ func convertInstanceMemoRelatedSettingToStore(setting *v1pb.InstanceSetting_Memo
}
}
var ownerCache *v1pb.User
func (s *APIV1Service) GetInstanceOwner(ctx context.Context) (*v1pb.User, error) {
if ownerCache != nil {
return ownerCache, nil
}
hostUserType := store.RoleHost
func (s *APIV1Service) GetInstanceAdmin(ctx context.Context) (*v1pb.User, error) {
adminUserType := store.RoleAdmin
user, err := s.Store.GetUser(ctx, &store.FindUser{
Role: &hostUserType,
Role: &adminUserType,
})
if err != nil {
return nil, errors.Wrapf(err, "failed to find owner")
return nil, errors.Wrapf(err, "failed to find admin")
}
if user == nil {
return nil, nil
}
ownerCache = convertUserFromStore(user)
return ownerCache, nil
return convertUserFromStore(user), nil
}

View File

@ -98,6 +98,24 @@ func (s *APIV1Service) ListMemoAttachments(ctx context.Context, request *v1pb.Li
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
// Check memo visibility.
if memo.Visibility != store.Public {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})

View File

@ -45,10 +45,35 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
Content: request.Memo.Content,
Visibility: convertVisibilityToStore(request.Memo.Visibility),
}
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
}
// Handle display_time first: if provided, use it to set the appropriate timestamp
// based on the instance setting (similar to UpdateMemo logic)
// Note: explicit create_time/update_time below will override this if provided
if request.Memo.DisplayTime != nil && request.Memo.DisplayTime.IsValid() {
displayTs := request.Memo.DisplayTime.AsTime().Unix()
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
create.UpdatedTs = displayTs
} else {
create.CreatedTs = displayTs
}
}
// Set custom timestamps if provided in the request
// These take precedence over display_time
if request.Memo.CreateTime != nil && request.Memo.CreateTime.IsValid() {
createdTs := request.Memo.CreateTime.AsTime().Unix()
create.CreatedTs = createdTs
}
if request.Memo.UpdateTime != nil && request.Memo.UpdateTime.IsValid() {
updatedTs := request.Memo.UpdateTime.AsTime().Unix()
create.UpdatedTs = updatedTs
}
if instanceMemoRelatedSetting.DisallowPublicVisibility && create.Visibility == store.Public {
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
}
@ -281,7 +306,7 @@ func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
@ -497,23 +522,7 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
}
}
if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo")
}
// Delete memo relation
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{MemoID: &memo.ID}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo relations")
}
// Delete related attachments.
for _, attachment := range attachments {
if err := s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{ID: attachment.ID}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete attachment")
}
}
// Delete memo comments
// Delete memo comments first (store.DeleteMemo handles their relations and attachments)
commentType := store.MemoRelationComment
relations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{RelatedMemoID: &memo.ID, Type: &commentType})
if err != nil {
@ -525,10 +534,9 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
}
}
// Delete memo references
referenceType := store.MemoRelationReference
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{RelatedMemoID: &memo.ID, Type: &referenceType}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo references")
// Delete the memo (store.DeleteMemo handles relation and attachment cleanup)
if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo")
}
return &emptypb.Empty{}, nil
@ -543,6 +551,21 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
if relatedMemo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
// Check memo visibility before allowing comment.
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if relatedMemo.Visibility == store.Private && relatedMemo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// Create the memo comment first.
memoComment, err := s.CreateMemo(ctx, &v1pb.CreateMemoRequest{

View File

@ -15,6 +15,33 @@ import (
)
func (s *APIV1Service) ListMemoReactions(ctx context.Context, request *v1pb.ListMemoReactionsRequest) (*v1pb.ListMemoReactionsResponse, error) {
// Extract memo UID and check visibility.
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
// Check memo visibility.
if memo.Visibility != store.Public {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
ContentID: &request.Name,
})
@ -40,6 +67,25 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Extract memo UID and check visibility before allowing reaction.
memoUID, err := ExtractMemoUIDFromName(request.Reaction.ContentId)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
// Check memo visibility.
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
reaction, err := s.Store.UpsertReaction(ctx, &store.Reaction{
CreatorID: user.ID,
ContentID: request.Reaction.ContentId,

View File

@ -97,7 +97,7 @@ func TestCreateIdentityProvider(t *testing.T) {
_, err := ts.Service.CreateIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
require.Contains(t, err.Error(), "user not authenticated")
})
}
@ -547,6 +547,6 @@ func TestIdentityProviderPermissions(t *testing.T) {
_, err := ts.Service.CreateIdentityProvider(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
require.Contains(t, err.Error(), "user not authenticated")
})
}

View File

@ -0,0 +1,54 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func TestInstanceAdminRetrieval(t *testing.T) {
ctx := context.Background()
t.Run("Instance becomes initialized after first admin user is created", func(t *testing.T) {
// Create test service
ts := NewTestService(t)
defer ts.Cleanup()
// Verify instance is not initialized initially
profile1, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
require.NoError(t, err)
require.Nil(t, profile1.Admin, "Instance should not be initialized before first admin user")
// Create the first admin user
user, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
require.NotNil(t, user)
// Verify instance is now initialized
profile2, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
require.NoError(t, err)
require.NotNil(t, profile2.Admin, "Instance should be initialized after first admin user is created")
require.Equal(t, user.Username, profile2.Admin.Username)
})
t.Run("Admin retrieval is cached by Store layer", func(t *testing.T) {
// Create test service
ts := NewTestService(t)
defer ts.Cleanup()
// Create admin user
user, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
// Multiple calls should return consistent admin user (from cache)
for i := 0; i < 5; i++ {
profile, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
require.NoError(t, err)
require.NotNil(t, profile.Admin)
require.Equal(t, user.Username, profile.Admin.Username)
}
})
}

View File

@ -2,7 +2,6 @@ package test
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
@ -28,14 +27,14 @@ func TestGetInstanceProfile(t *testing.T) {
// Verify the response contains expected data
require.Equal(t, "test-1.0.0", resp.Version)
require.Equal(t, "dev", resp.Mode)
require.True(t, resp.Demo)
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
// Owner should be empty since no users are created
require.Empty(t, resp.Owner)
// Instance should not be initialized since no admin users are created
require.Nil(t, resp.Admin)
})
t.Run("GetInstanceProfile with owner", func(t *testing.T) {
t.Run("GetInstanceProfile with initialized instance", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
defer ts.Cleanup()
@ -53,14 +52,14 @@ func TestGetInstanceProfile(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, resp)
// Verify the response contains expected data including owner
// Verify the response contains expected data with initialized flag
require.Equal(t, "test-1.0.0", resp.Version)
require.Equal(t, "dev", resp.Mode)
require.True(t, resp.Demo)
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
// User name should be "users/{id}" format where id is the user's ID
expectedOwnerName := fmt.Sprintf("users/%d", hostUser.ID)
require.Equal(t, expectedOwnerName, resp.Owner)
// Instance should be initialized since an admin user exists
require.NotNil(t, resp.Admin)
require.Equal(t, hostUser.Username, resp.Admin.Username)
})
}
@ -73,9 +72,8 @@ func TestGetInstanceProfile_Concurrency(t *testing.T) {
defer ts.Cleanup()
// Create a host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
_, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
expectedOwnerName := fmt.Sprintf("users/%d", hostUser.ID)
// Make concurrent requests
numGoroutines := 10
@ -102,9 +100,9 @@ func TestGetInstanceProfile_Concurrency(t *testing.T) {
case resp := <-results:
require.NotNil(t, resp)
require.Equal(t, "test-1.0.0", resp.Version)
require.Equal(t, "dev", resp.Mode)
require.True(t, resp.Demo)
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
require.Equal(t, expectedOwnerName, resp.Owner)
require.NotNil(t, resp.Admin)
}
}
})

View File

@ -5,8 +5,10 @@ import (
"fmt"
"slices"
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
)
@ -250,3 +252,118 @@ func TestListMemos(t *testing.T) {
require.NotNil(t, userTwoReaction)
require.Equal(t, "👍", userTwoReaction.ReactionType)
}
// TestCreateMemoWithCustomTimestamps tests that custom timestamps can be set when creating memos and comments.
// This addresses issue #5483: https://github.com/usememos/memos/issues/5483
func TestCreateMemoWithCustomTimestamps(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
// Create a test user
user, err := ts.CreateRegularUser(ctx, "test-user-timestamps")
require.NoError(t, err)
require.NotNil(t, user)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Define custom timestamps (January 1, 2020)
customCreateTime := time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC)
customUpdateTime := time.Date(2020, 1, 2, 12, 0, 0, 0, time.UTC)
customDisplayTime := time.Date(2020, 1, 3, 12, 0, 0, 0, time.UTC)
// Test 1: Create a memo with custom create_time
memoWithCreateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has a custom creation time",
Visibility: apiv1.Visibility_PRIVATE,
CreateTime: timestamppb.New(customCreateTime),
},
})
require.NoError(t, err)
require.NotNil(t, memoWithCreateTime)
require.Equal(t, customCreateTime.Unix(), memoWithCreateTime.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp")
// Test 2: Create a memo with custom update_time
memoWithUpdateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has a custom update time",
Visibility: apiv1.Visibility_PRIVATE,
UpdateTime: timestamppb.New(customUpdateTime),
},
})
require.NoError(t, err)
require.NotNil(t, memoWithUpdateTime)
require.Equal(t, customUpdateTime.Unix(), memoWithUpdateTime.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp")
// Test 3: Create a memo with custom display_time
// Note: display_time is computed from either created_ts or updated_ts based on instance setting
// Since DisplayWithUpdateTime defaults to false, display_time maps to created_ts
memoWithDisplayTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has a custom display time",
Visibility: apiv1.Visibility_PRIVATE,
DisplayTime: timestamppb.New(customDisplayTime),
},
})
require.NoError(t, err)
require.NotNil(t, memoWithDisplayTime)
// Since DisplayWithUpdateTime is false by default, display_time sets created_ts
require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.DisplayTime.AsTime().Unix(), "display_time should match the custom timestamp")
require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.CreateTime.AsTime().Unix(), "create_time should also match since display_time maps to created_ts")
// Test 4: Create a memo with all custom timestamps
// When both display_time and create_time are provided, create_time takes precedence
memoWithAllTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has all custom timestamps",
Visibility: apiv1.Visibility_PRIVATE,
CreateTime: timestamppb.New(customCreateTime),
UpdateTime: timestamppb.New(customUpdateTime),
DisplayTime: timestamppb.New(customDisplayTime),
},
})
require.NoError(t, err)
require.NotNil(t, memoWithAllTimestamps)
require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp")
require.Equal(t, customUpdateTime.Unix(), memoWithAllTimestamps.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp")
// display_time is computed from created_ts when DisplayWithUpdateTime is false
require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.DisplayTime.AsTime().Unix(), "display_time should be derived from create_time")
// Test 5: Create a comment (memo relation) with custom timestamps
parentMemo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This is the parent memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, parentMemo)
customCommentCreateTime := time.Date(2021, 6, 15, 10, 30, 0, 0, time.UTC)
comment, err := ts.Service.CreateMemoComment(userCtx, &apiv1.CreateMemoCommentRequest{
Name: parentMemo.Name,
Comment: &apiv1.Memo{
Content: "This is a comment with custom create time",
Visibility: apiv1.Visibility_PRIVATE,
CreateTime: timestamppb.New(customCommentCreateTime),
},
})
require.NoError(t, err)
require.NotNil(t, comment)
require.Equal(t, customCommentCreateTime.Unix(), comment.CreateTime.AsTime().Unix(), "comment create_time should match the custom timestamp")
// Test 6: Verify that memos without custom timestamps still get auto-generated ones
memoWithoutTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "This memo has auto-generated timestamps",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memoWithoutTimestamps)
require.NotNil(t, memoWithoutTimestamps.CreateTime, "create_time should be auto-generated")
require.NotNil(t, memoWithoutTimestamps.UpdateTime, "update_time should be auto-generated")
require.True(t, time.Now().Unix()-memoWithoutTimestamps.CreateTime.AsTime().Unix() < 5, "create_time should be recent (within 5 seconds)")
}

View File

@ -29,7 +29,7 @@ func NewTestService(t *testing.T) *TestService {
// Create a test profile
testProfile := &profile.Profile{
Mode: "dev",
Demo: true,
Version: "test-1.0.0",
InstanceURL: "http://localhost:8080",
Driver: "sqlite",
@ -56,17 +56,16 @@ func NewTestService(t *testing.T) *TestService {
}
}
// Cleanup clears caches and closes resources after test.
// Cleanup closes resources after test.
func (ts *TestService) Cleanup() {
ts.Store.Close()
// Note: Owner cache is package-level in parent package, cannot clear from test package
}
// CreateHostUser creates a host user for testing.
// CreateHostUser creates an admin user for testing.
func (ts *TestService) CreateHostUser(ctx context.Context, username string) (*store.User, error) {
return ts.Store.CreateUser(ctx, &store.User{
Username: username,
Role: store.RoleHost,
Role: store.RoleAdmin,
Email: username + "@example.com",
})
}

View File

@ -37,7 +37,7 @@ func (s *APIV1Service) ListUsers(ctx context.Context, request *v1pb.ListUsersReq
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
if currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
@ -132,17 +132,17 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR
// Determine the role to assign
var roleToAssign store.Role
if isFirstUser {
// First-time setup: create the first user as HOST (no authentication required)
roleToAssign = store.RoleHost
} else if currentUser != nil && currentUser.Role == store.RoleHost {
// Authenticated HOST user can create users with any role specified in request
// First-time setup: create the first user as ADMIN (no authentication required)
roleToAssign = store.RoleAdmin
} else if currentUser != nil && currentUser.Role == store.RoleAdmin {
// Authenticated ADMIN user can create users with any role specified in request
if request.User.Role != v1pb.User_ROLE_UNSPECIFIED {
roleToAssign = convertUserRoleToStore(request.User.Role)
} else {
roleToAssign = store.RoleUser
}
} else {
// Unauthenticated or non-HOST users can only create normal users
// Unauthenticated or non-ADMIN users can only create normal users
roleToAssign = store.RoleUser
}
@ -192,9 +192,12 @@ func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserR
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Check permission.
// Only allow admin or self to update user.
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
@ -261,7 +264,7 @@ func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserR
update.Description = &request.User.Description
case "role":
// Only allow admin to update role.
if currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
if currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
role := convertUserRoleToStore(request.User.Role)
@ -298,7 +301,10 @@ func (s *APIV1Service) DeleteUser(ctx context.Context, request *v1pb.DeleteUserR
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
@ -539,7 +545,7 @@ func (s *APIV1Service) ListPersonalAccessTokens(ctx context.Context, request *v1
claims := auth.GetUserClaims(ctx)
if claims == nil || claims.UserID != userID {
currentUser, _ := s.fetchCurrentUser(ctx)
if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin) {
if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleAdmin) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
@ -686,7 +692,7 @@ func (s *APIV1Service) ListUserWebhooks(ctx context.Context, request *v1pb.ListU
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
@ -718,7 +724,7 @@ func (s *APIV1Service) CreateUserWebhook(ctx context.Context, request *v1pb.Crea
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
@ -758,7 +764,7 @@ func (s *APIV1Service) UpdateUserWebhook(ctx context.Context, request *v1pb.Upda
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
@ -830,7 +836,7 @@ func (s *APIV1Service) DeleteUserWebhook(ctx context.Context, request *v1pb.Dele
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
@ -925,8 +931,6 @@ func convertUserFromStore(user *store.User) *v1pb.User {
func convertUserRoleFromStore(role store.Role) v1pb.User_Role {
switch role {
case store.RoleHost:
return v1pb.User_HOST
case store.RoleAdmin:
return v1pb.User_ADMIN
case store.RoleUser:
@ -938,8 +942,6 @@ func convertUserRoleFromStore(role store.Role) v1pb.User_Role {
func convertUserRoleToStore(role v1pb.User_Role) store.Role {
switch role {
case v1pb.User_HOST:
return store.RoleHost
case v1pb.User_ADMIN:
return store.RoleAdmin
default:
@ -1240,6 +1242,9 @@ func (s *APIV1Service) ListUserNotifications(ctx context.Context, request *v1pb.
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.ID != userID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
@ -1287,6 +1292,9 @@ func (s *APIV1Service) UpdateUserNotification(ctx context.Context, request *v1pb
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Verify ownership before updating
inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{
ID: &notificationID,
@ -1352,6 +1360,9 @@ func (s *APIV1Service) DeleteUserNotification(ctx context.Context, request *v1pb
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Verify ownership before deletion
inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{
ID: &notificationID,

View File

@ -59,7 +59,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
ctx := r.Context()
// Get the RPC method name from context (set by grpc-gateway after routing)
rpcMethod, _ := runtime.RPCMethod(ctx)
rpcMethod, ok := runtime.RPCMethod(ctx)
// Extract credentials from HTTP headers
authHeader := r.Header.Get("Authorization")
@ -67,7 +67,8 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
result := authenticator.Authenticate(ctx, authHeader)
// Enforce authentication for non-public methods
if result == nil && !IsPublicMethod(rpcMethod) {
// If rpcMethod cannot be determined, allow through, service layer will handle visibility checks
if result == nil && ok && !IsPublicMethod(rpcMethod) {
http.Error(w, `{"code": 16, "message": "authentication required"}`, http.StatusUnauthorized)
return
}
@ -125,7 +126,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
gwGroup.Any("/file/*", handler)
// Connect handlers for browser clients (replaces grpc-web).
logStacktraces := s.Profile.IsDev()
logStacktraces := s.Profile.Demo
connectInterceptors := connect.WithInterceptors(
NewMetadataInterceptor(), // Convert HTTP headers to gRPC metadata first
NewLoggingInterceptor(logStacktraces),

View File

@ -26,13 +26,55 @@ import (
"github.com/usememos/memos/store"
)
// Constants for file serving configuration.
const (
// ThumbnailCacheFolder is the folder name where the thumbnail images are stored.
// ThumbnailCacheFolder is the folder name where thumbnail images are stored.
ThumbnailCacheFolder = ".thumbnail_cache"
// thumbnailMaxSize is the maximum size in pixels for the largest dimension of the thumbnail image.
// thumbnailMaxSize is the maximum dimension (width or height) for thumbnails.
thumbnailMaxSize = 600
// maxConcurrentThumbnails limits concurrent thumbnail generation to prevent memory exhaustion.
maxConcurrentThumbnails = 3
// cacheMaxAge is the max-age value for Cache-Control headers (1 hour).
cacheMaxAge = "public, max-age=3600"
)
// xssUnsafeTypes contains MIME types that could execute scripts if served directly.
// These are served as application/octet-stream to prevent XSS attacks.
var xssUnsafeTypes = map[string]bool{
"text/html": true,
"text/javascript": true,
"application/javascript": true,
"application/x-javascript": true,
"text/xml": true,
"application/xml": true,
"application/xhtml+xml": true,
"image/svg+xml": true,
}
// thumbnailSupportedTypes contains image MIME types that support thumbnail generation.
var thumbnailSupportedTypes = map[string]bool{
"image/png": true,
"image/jpeg": true,
"image/heic": true,
"image/heif": true,
"image/webp": true,
}
// avatarAllowedTypes contains MIME types allowed for user avatars.
var avatarAllowedTypes = map[string]bool{
"image/png": true,
"image/jpeg": true,
"image/jpg": true,
"image/gif": true,
"image/webp": true,
"image/heic": true,
"image/heif": true,
}
// SupportedThumbnailMimeTypes is the exported list of thumbnail-supported MIME types.
var SupportedThumbnailMimeTypes = []string{
"image/png",
"image/jpeg",
@ -41,15 +83,16 @@ var SupportedThumbnailMimeTypes = []string{
"image/webp",
}
// dataURIRegex parses data URI format: ...
var dataURIRegex = regexp.MustCompile(`^data:(?P<type>[^;]+);base64,(?P<base64>.+)`)
// FileServerService handles HTTP file serving with proper range request support.
// This service bypasses gRPC-Gateway to use native HTTP serving via http.ServeContent(),
// which is required for Safari video/audio playback.
type FileServerService struct {
Profile *profile.Profile
Store *store.Store
authenticator *auth.Authenticator
// thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion
// thumbnailSemaphore limits concurrent thumbnail generation.
thumbnailSemaphore *semaphore.Weighted
}
@ -59,29 +102,27 @@ func NewFileServerService(profile *profile.Profile, store *store.Store, secret s
Profile: profile,
Store: store,
authenticator: auth.NewAuthenticator(store, secret),
thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations
thumbnailSemaphore: semaphore.NewWeighted(maxConcurrentThumbnails),
}
}
// RegisterRoutes registers HTTP file serving routes.
func (s *FileServerService) RegisterRoutes(echoServer *echo.Echo) {
fileGroup := echoServer.Group("/file")
// Serve attachment binary files
fileGroup.GET("/attachments/:uid/:filename", s.serveAttachmentFile)
// Serve user avatar images
fileGroup.GET("/users/:identifier/avatar", s.serveUserAvatar)
}
// =============================================================================
// HTTP Handlers
// =============================================================================
// serveAttachmentFile serves attachment binary content using native HTTP.
// This properly handles range requests required by Safari for video/audio playback.
func (s *FileServerService) serveAttachmentFile(c echo.Context) error {
ctx := c.Request().Context()
uid := c.Param("uid")
thumbnail := c.QueryParam("thumbnail") == "true"
wantThumbnail := c.QueryParam("thumbnail") == "true"
// Get attachment from database
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{
UID: &uid,
GetBlob: true,
@ -93,96 +134,25 @@ func (s *FileServerService) serveAttachmentFile(c echo.Context) error {
return echo.NewHTTPError(http.StatusNotFound, "attachment not found")
}
// Check permissions - verify memo visibility if attachment belongs to a memo
if err := s.checkAttachmentPermission(ctx, c, attachment); err != nil {
return err
}
// Get the binary content
blob, err := s.getAttachmentBlob(attachment)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment blob").SetInternal(err)
contentType := s.sanitizeContentType(attachment.Type)
// Stream video/audio to avoid loading entire file into memory.
if isMediaType(attachment.Type) {
return s.serveMediaStream(c, attachment, contentType)
}
// Handle thumbnail requests for images
if thumbnail && s.isImageType(attachment.Type) {
thumbnailBlob, err := s.getOrGenerateThumbnail(ctx, attachment)
if err != nil {
// Log warning but fall back to original image
c.Logger().Warnf("failed to get thumbnail: %v", err)
} else {
blob = thumbnailBlob
}
}
// Determine content type
contentType := attachment.Type
if strings.HasPrefix(contentType, "text/") {
contentType += "; charset=utf-8"
}
// Prevent XSS attacks by serving potentially unsafe files as octet-stream
unsafeTypes := []string{
"text/html",
"text/javascript",
"application/javascript",
"application/x-javascript",
"text/xml",
"application/xml",
"application/xhtml+xml",
"image/svg+xml",
}
for _, unsafeType := range unsafeTypes {
if strings.EqualFold(contentType, unsafeType) {
contentType = "application/octet-stream"
break
}
}
// Set common headers
c.Response().Header().Set("Content-Type", contentType)
c.Response().Header().Set("Cache-Control", "public, max-age=3600")
// Prevent MIME-type sniffing which could lead to XSS
c.Response().Header().Set("X-Content-Type-Options", "nosniff")
// Defense-in-depth: prevent embedding in frames and restrict content loading
c.Response().Header().Set("X-Frame-Options", "DENY")
c.Response().Header().Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline';")
// Support HDR/wide color gamut display for capable browsers
if strings.HasPrefix(contentType, "image/") || strings.HasPrefix(contentType, "video/") {
c.Response().Header().Set("Color-Gamut", "srgb, p3, rec2020")
}
// Force download for non-media files to prevent XSS execution
if !strings.HasPrefix(contentType, "image/") &&
!strings.HasPrefix(contentType, "video/") &&
!strings.HasPrefix(contentType, "audio/") &&
contentType != "application/pdf" {
c.Response().Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", attachment.Filename))
}
// For video/audio: Use http.ServeContent for automatic range request support
// This is critical for Safari which REQUIRES range request support
if strings.HasPrefix(contentType, "video/") || strings.HasPrefix(contentType, "audio/") {
// ServeContent automatically handles:
// - Range request parsing
// - HTTP 206 Partial Content responses
// - Content-Range headers
// - Accept-Ranges: bytes header
modTime := time.Unix(attachment.UpdatedTs, 0)
http.ServeContent(c.Response(), c.Request(), attachment.Filename, modTime, bytes.NewReader(blob))
return nil
}
// For other files: Simple blob response
return c.Blob(http.StatusOK, contentType, blob)
return s.serveStaticFile(c, attachment, contentType, wantThumbnail)
}
// serveUserAvatar serves user avatar images.
// Supports both user ID and username as identifier.
func (s *FileServerService) serveUserAvatar(c echo.Context) error {
ctx := c.Request().Context()
identifier := c.Param("identifier")
// Try to find user by ID or username
user, err := s.getUserByIdentifier(ctx, identifier)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get user").SetInternal(err)
@ -194,79 +164,327 @@ func (s *FileServerService) serveUserAvatar(c echo.Context) error {
return echo.NewHTTPError(http.StatusNotFound, "avatar not found")
}
// Extract image info from data URI
imageType, base64Data, err := s.extractImageInfo(user.AvatarURL)
imageType, imageData, err := s.parseDataURI(user.AvatarURL)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to extract image info").SetInternal(err)
return echo.NewHTTPError(http.StatusInternalServerError, "failed to parse avatar data").SetInternal(err)
}
// Validate avatar MIME type to prevent XSS
// Supports standard formats and HDR-capable formats
allowedAvatarTypes := map[string]bool{
"image/png": true,
"image/jpeg": true,
"image/jpg": true,
"image/gif": true,
"image/webp": true,
"image/heic": true,
"image/heif": true,
}
if !allowedAvatarTypes[imageType] {
if !avatarAllowedTypes[imageType] {
return echo.NewHTTPError(http.StatusBadRequest, "invalid avatar image type")
}
// Decode base64 data
imageData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to decode image data").SetInternal(err)
}
// Set cache headers for avatars
c.Response().Header().Set("Content-Type", imageType)
c.Response().Header().Set("Cache-Control", "public, max-age=3600")
c.Response().Header().Set("X-Content-Type-Options", "nosniff")
// Defense-in-depth: prevent embedding in frames
c.Response().Header().Set("X-Frame-Options", "DENY")
c.Response().Header().Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline';")
setSecurityHeaders(c)
c.Response().Header().Set(echo.HeaderContentType, imageType)
c.Response().Header().Set(echo.HeaderCacheControl, cacheMaxAge)
return c.Blob(http.StatusOK, imageType, imageData)
}
// getUserByIdentifier finds a user by either ID or username.
func (s *FileServerService) getUserByIdentifier(ctx context.Context, identifier string) (*store.User, error) {
// Try to parse as ID first
if userID, err := util.ConvertStringToInt32(identifier); err == nil {
return s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
}
// =============================================================================
// File Serving Methods
// =============================================================================
// Otherwise, treat as username
return s.Store.GetUser(ctx, &store.FindUser{Username: &identifier})
// serveMediaStream serves video/audio files using streaming to avoid memory exhaustion.
func (s *FileServerService) serveMediaStream(c echo.Context, attachment *store.Attachment, contentType string) error {
setSecurityHeaders(c)
setMediaHeaders(c, contentType, attachment.Type)
switch attachment.StorageType {
case storepb.AttachmentStorageType_LOCAL:
filePath, err := s.resolveLocalPath(attachment.Reference)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to resolve file path").SetInternal(err)
}
http.ServeFile(c.Response(), c.Request(), filePath)
return nil
case storepb.AttachmentStorageType_S3:
presignURL, err := s.getS3PresignedURL(c.Request().Context(), attachment)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to generate presigned URL").SetInternal(err)
}
return c.Redirect(http.StatusTemporaryRedirect, presignURL)
default:
// Database storage fallback.
modTime := time.Unix(attachment.UpdatedTs, 0)
http.ServeContent(c.Response(), c.Request(), attachment.Filename, modTime, bytes.NewReader(attachment.Blob))
return nil
}
}
// extractImageInfo extracts image type and base64 data from a data URI.
// Data URI format: ...
func (*FileServerService) extractImageInfo(dataURI string) (string, string, error) {
dataURIRegex := regexp.MustCompile(`^data:(?P<type>.+);base64,(?P<base64>.+)`)
matches := dataURIRegex.FindStringSubmatch(dataURI)
if len(matches) != 3 {
return "", "", errors.New("invalid data URI format")
// serveStaticFile serves non-streaming files (images, documents, etc.).
func (s *FileServerService) serveStaticFile(c echo.Context, attachment *store.Attachment, contentType string, wantThumbnail bool) error {
blob, err := s.getAttachmentBlob(attachment)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment blob").SetInternal(err)
}
imageType := matches[1]
base64Data := matches[2]
return imageType, base64Data, nil
// Generate thumbnail for supported image types.
if wantThumbnail && thumbnailSupportedTypes[attachment.Type] {
if thumbnailBlob, err := s.getOrGenerateThumbnail(c.Request().Context(), attachment); err != nil {
c.Logger().Warnf("failed to get thumbnail: %v", err)
} else {
blob = thumbnailBlob
}
}
setSecurityHeaders(c)
setMediaHeaders(c, contentType, attachment.Type)
// Force download for non-media files to prevent XSS execution.
if !strings.HasPrefix(contentType, "image/") && contentType != "application/pdf" {
c.Response().Header().Set(echo.HeaderContentDisposition, fmt.Sprintf("attachment; filename=%q", attachment.Filename))
}
return c.Blob(http.StatusOK, contentType, blob)
}
// =============================================================================
// Storage Operations
// =============================================================================
// getAttachmentBlob retrieves the binary content of an attachment from storage.
func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]byte, error) {
switch attachment.StorageType {
case storepb.AttachmentStorageType_LOCAL:
return s.readLocalFile(attachment.Reference)
case storepb.AttachmentStorageType_S3:
return s.downloadFromS3(attachment)
default:
return attachment.Blob, nil
}
}
// getAttachmentReader returns a reader for streaming attachment content.
func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (io.ReadCloser, error) {
switch attachment.StorageType {
case storepb.AttachmentStorageType_LOCAL:
filePath, err := s.resolveLocalPath(attachment.Reference)
if err != nil {
return nil, err
}
file, err := os.Open(filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, errors.Wrap(err, "file not found")
}
return nil, errors.Wrap(err, "failed to open file")
}
return file, nil
case storepb.AttachmentStorageType_S3:
s3Client, s3Object, err := s.createS3Client(attachment)
if err != nil {
return nil, err
}
reader, err := s3Client.GetObjectStream(context.Background(), s3Object.Key)
if err != nil {
return nil, errors.Wrap(err, "failed to stream from S3")
}
return reader, nil
default:
return io.NopCloser(bytes.NewReader(attachment.Blob)), nil
}
}
// resolveLocalPath converts a storage reference to an absolute file path.
func (s *FileServerService) resolveLocalPath(reference string) (string, error) {
filePath := filepath.FromSlash(reference)
if !filepath.IsAbs(filePath) {
filePath = filepath.Join(s.Profile.Data, filePath)
}
return filePath, nil
}
// readLocalFile reads the entire contents of a local file.
func (s *FileServerService) readLocalFile(reference string) ([]byte, error) {
filePath, err := s.resolveLocalPath(reference)
if err != nil {
return nil, err
}
file, err := os.Open(filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, errors.Wrap(err, "file not found")
}
return nil, errors.Wrap(err, "failed to open file")
}
defer file.Close()
blob, err := io.ReadAll(file)
if err != nil {
return nil, errors.Wrap(err, "failed to read file")
}
return blob, nil
}
// createS3Client creates an S3 client from attachment payload.
func (*FileServerService) createS3Client(attachment *store.Attachment) (*s3.Client, *storepb.AttachmentPayload_S3Object, error) {
if attachment.Payload == nil {
return nil, nil, errors.New("attachment payload is missing")
}
s3Object := attachment.Payload.GetS3Object()
if s3Object == nil {
return nil, nil, errors.New("S3 object payload is missing")
}
if s3Object.S3Config == nil {
return nil, nil, errors.New("S3 config is missing")
}
if s3Object.Key == "" {
return nil, nil, errors.New("S3 object key is missing")
}
client, err := s3.NewClient(context.Background(), s3Object.S3Config)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to create S3 client")
}
return client, s3Object, nil
}
// downloadFromS3 downloads the entire object from S3.
func (s *FileServerService) downloadFromS3(attachment *store.Attachment) ([]byte, error) {
client, s3Object, err := s.createS3Client(attachment)
if err != nil {
return nil, err
}
blob, err := client.GetObject(context.Background(), s3Object.Key)
if err != nil {
return nil, errors.Wrap(err, "failed to download from S3")
}
return blob, nil
}
// getS3PresignedURL generates a presigned URL for direct S3 access.
func (s *FileServerService) getS3PresignedURL(ctx context.Context, attachment *store.Attachment) (string, error) {
client, s3Object, err := s.createS3Client(attachment)
if err != nil {
return "", err
}
url, err := client.PresignGetObject(ctx, s3Object.Key)
if err != nil {
return "", errors.Wrap(err, "failed to presign URL")
}
return url, nil
}
// =============================================================================
// Thumbnail Generation
// =============================================================================
// getOrGenerateThumbnail returns the thumbnail image of the attachment.
// Uses semaphore to limit concurrent thumbnail generation and prevent memory exhaustion.
func (s *FileServerService) getOrGenerateThumbnail(ctx context.Context, attachment *store.Attachment) ([]byte, error) {
thumbnailPath, err := s.getThumbnailPath(attachment)
if err != nil {
return nil, err
}
// Fast path: return cached thumbnail if exists.
if blob, err := s.readCachedThumbnail(thumbnailPath); err == nil {
return blob, nil
}
// Acquire semaphore to limit concurrent generation.
if err := s.thumbnailSemaphore.Acquire(ctx, 1); err != nil {
return nil, errors.Wrap(err, "failed to acquire semaphore")
}
defer s.thumbnailSemaphore.Release(1)
// Double-check after acquiring semaphore (another goroutine may have generated it).
if blob, err := s.readCachedThumbnail(thumbnailPath); err == nil {
return blob, nil
}
return s.generateThumbnail(attachment, thumbnailPath)
}
// getThumbnailPath returns the file path for a cached thumbnail.
func (s *FileServerService) getThumbnailPath(attachment *store.Attachment) (string, error) {
cacheFolder := filepath.Join(s.Profile.Data, ThumbnailCacheFolder)
if err := os.MkdirAll(cacheFolder, os.ModePerm); err != nil {
return "", errors.Wrap(err, "failed to create thumbnail cache folder")
}
filename := fmt.Sprintf("%d%s", attachment.ID, filepath.Ext(attachment.Filename))
return filepath.Join(cacheFolder, filename), nil
}
// readCachedThumbnail reads a thumbnail from the cache directory.
func (*FileServerService) readCachedThumbnail(path string) ([]byte, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
return io.ReadAll(file)
}
// generateThumbnail creates a new thumbnail and saves it to disk.
func (s *FileServerService) generateThumbnail(attachment *store.Attachment, thumbnailPath string) ([]byte, error) {
reader, err := s.getAttachmentReader(attachment)
if err != nil {
return nil, errors.Wrap(err, "failed to get attachment reader")
}
defer reader.Close()
img, err := imaging.Decode(reader, imaging.AutoOrientation(true))
if err != nil {
return nil, errors.Wrap(err, "failed to decode image")
}
width, height := img.Bounds().Dx(), img.Bounds().Dy()
thumbnailWidth, thumbnailHeight := calculateThumbnailDimensions(width, height)
thumbnailImage := imaging.Resize(img, thumbnailWidth, thumbnailHeight, imaging.Lanczos)
if err := imaging.Save(thumbnailImage, thumbnailPath); err != nil {
return nil, errors.Wrap(err, "failed to save thumbnail")
}
return s.readCachedThumbnail(thumbnailPath)
}
// calculateThumbnailDimensions calculates the target dimensions for a thumbnail.
// The largest dimension is constrained to thumbnailMaxSize while maintaining aspect ratio.
// Small images are not enlarged.
func calculateThumbnailDimensions(width, height int) (int, int) {
if max(width, height) <= thumbnailMaxSize {
return width, height
}
if width >= height {
return thumbnailMaxSize, 0 // Landscape: constrain width.
}
return 0, thumbnailMaxSize // Portrait: constrain height.
}
// =============================================================================
// Authentication & Authorization
// =============================================================================
// checkAttachmentPermission verifies the user has permission to access the attachment.
func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c echo.Context, attachment *store.Attachment) error {
// If attachment is not linked to a memo, allow access
// For unlinked attachments, only the creator can access.
if attachment.MemoID == nil {
user, err := s.getCurrentUser(ctx, c)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get current user").SetInternal(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "unauthorized access")
}
if user.ID != attachment.CreatorID && user.Role != store.RoleAdmin {
return echo.NewHTTPError(http.StatusForbidden, "forbidden access")
}
return nil
}
// Check memo visibility
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: attachment.MemoID,
})
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to find memo").SetInternal(err)
}
@ -274,12 +492,10 @@ func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c ech
return echo.NewHTTPError(http.StatusNotFound, "memo not found")
}
// Public memos are accessible to everyone
if memo.Visibility == store.Public {
return nil
}
// For non-public memos, check authentication
user, err := s.getCurrentUser(ctx, c)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get current user").SetInternal(err)
@ -288,278 +504,132 @@ func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c ech
return echo.NewHTTPError(http.StatusUnauthorized, "unauthorized access")
}
// Private memos can only be accessed by the creator
if memo.Visibility == store.Private && user.ID != attachment.CreatorID {
if memo.Visibility == store.Private && user.ID != memo.CreatorID && user.Role != store.RoleAdmin {
return echo.NewHTTPError(http.StatusForbidden, "forbidden access")
}
return nil
}
// getCurrentUser retrieves the current authenticated user from the Echo context.
// getCurrentUser retrieves the current authenticated user from the request.
// Authentication priority: Bearer token (Access Token V2 or PAT) > Refresh token cookie.
// Uses the shared Authenticator for consistent authentication logic.
func (s *FileServerService) getCurrentUser(ctx context.Context, c echo.Context) (*store.User, error) {
// Try Bearer token authentication first
authHeader := c.Request().Header.Get("Authorization")
if authHeader != "" {
token := auth.ExtractBearerToken(authHeader)
if token != "" {
// Try Access Token V2 (stateless)
if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) {
claims, err := s.authenticator.AuthenticateByAccessTokenV2(token)
if err == nil && claims != nil {
// Get user from claims
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &claims.UserID})
if err == nil && user != nil {
return user, nil
}
}
}
// Try PAT
if strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) {
user, _, err := s.authenticator.AuthenticateByPAT(ctx, token)
if err == nil && user != nil {
return user, nil
}
}
// Try Bearer token authentication.
if authHeader := c.Request().Header.Get(echo.HeaderAuthorization); authHeader != "" {
if user, err := s.authenticateByBearerToken(ctx, authHeader); err == nil && user != nil {
return user, nil
}
}
// Fallback: Try refresh token cookie authentication
// This allows protected attachments to load even when access token has expired,
// as long as the user has a valid refresh token cookie.
cookieHeader := c.Request().Header.Get("Cookie")
if cookieHeader != "" {
refreshToken := auth.ExtractRefreshTokenFromCookie(cookieHeader)
if refreshToken != "" {
user, _, err := s.authenticator.AuthenticateByRefreshToken(ctx, refreshToken)
if err == nil && user != nil {
return user, nil
}
// Fallback: Try refresh token cookie.
if cookieHeader := c.Request().Header.Get("Cookie"); cookieHeader != "" {
if user, err := s.authenticateByRefreshToken(ctx, cookieHeader); err == nil && user != nil {
return user, nil
}
}
// No valid authentication found
return nil, nil
}
// isImageType checks if the mime type is an image that supports thumbnails.
// Supports standard formats (PNG, JPEG) and HDR-capable formats (HEIC, HEIF, WebP).
func (*FileServerService) isImageType(mimeType string) bool {
supportedTypes := map[string]bool{
"image/png": true,
"image/jpeg": true,
"image/heic": true,
"image/heif": true,
"image/webp": true,
// authenticateByBearerToken authenticates using Authorization header.
func (s *FileServerService) authenticateByBearerToken(ctx context.Context, authHeader string) (*store.User, error) {
token := auth.ExtractBearerToken(authHeader)
if token == "" {
return nil, nil
}
return supportedTypes[mimeType]
// Try Access Token V2 (stateless JWT).
if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) {
claims, err := s.authenticator.AuthenticateByAccessTokenV2(token)
if err == nil && claims != nil {
return s.Store.GetUser(ctx, &store.FindUser{ID: &claims.UserID})
}
}
// Try Personal Access Token (stateful).
if strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) {
user, _, err := s.authenticator.AuthenticateByPAT(ctx, token)
if err == nil {
return user, nil
}
}
return nil, nil
}
// getAttachmentReader returns a reader for the attachment content.
func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (io.ReadCloser, error) {
// For local storage, read the file from the local disk.
if attachment.StorageType == storepb.AttachmentStorageType_LOCAL {
attachmentPath := filepath.FromSlash(attachment.Reference)
if !filepath.IsAbs(attachmentPath) {
attachmentPath = filepath.Join(s.Profile.Data, attachmentPath)
}
file, err := os.Open(attachmentPath)
if err != nil {
if os.IsNotExist(err) {
return nil, errors.Wrap(err, "file not found")
}
return nil, errors.Wrap(err, "failed to open the file")
}
return file, nil
// authenticateByRefreshToken authenticates using refresh token cookie.
func (s *FileServerService) authenticateByRefreshToken(ctx context.Context, cookieHeader string) (*store.User, error) {
refreshToken := auth.ExtractRefreshTokenFromCookie(cookieHeader)
if refreshToken == "" {
return nil, nil
}
// For S3 storage, download the file from S3.
if attachment.StorageType == storepb.AttachmentStorageType_S3 {
if attachment.Payload == nil {
return nil, errors.New("attachment payload is missing")
}
s3Object := attachment.Payload.GetS3Object()
if s3Object == nil {
return nil, errors.New("S3 object payload is missing")
}
if s3Object.S3Config == nil {
return nil, errors.New("S3 config is missing")
}
if s3Object.Key == "" {
return nil, errors.New("S3 object key is missing")
}
s3Client, err := s3.NewClient(context.Background(), s3Object.S3Config)
if err != nil {
return nil, errors.Wrap(err, "failed to create S3 client")
}
reader, err := s3Client.GetObjectStream(context.Background(), s3Object.Key)
if err != nil {
return nil, errors.Wrap(err, "failed to get object from S3")
}
return reader, nil
}
// For database storage, return the blob from the database.
return io.NopCloser(bytes.NewReader(attachment.Blob)), nil
user, _, err := s.authenticator.AuthenticateByRefreshToken(ctx, refreshToken)
return user, err
}
// getAttachmentBlob retrieves the binary content of an attachment from storage.
func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]byte, error) {
// For local storage, read the file from the local disk.
if attachment.StorageType == storepb.AttachmentStorageType_LOCAL {
attachmentPath := filepath.FromSlash(attachment.Reference)
if !filepath.IsAbs(attachmentPath) {
attachmentPath = filepath.Join(s.Profile.Data, attachmentPath)
}
file, err := os.Open(attachmentPath)
if err != nil {
if os.IsNotExist(err) {
return nil, errors.Wrap(err, "file not found")
}
return nil, errors.Wrap(err, "failed to open the file")
}
defer file.Close()
blob, err := io.ReadAll(file)
if err != nil {
return nil, errors.Wrap(err, "failed to read the file")
}
return blob, nil
// getUserByIdentifier finds a user by either ID or username.
func (s *FileServerService) getUserByIdentifier(ctx context.Context, identifier string) (*store.User, error) {
if userID, err := util.ConvertStringToInt32(identifier); err == nil {
return s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
}
// For S3 storage, download the file from S3.
if attachment.StorageType == storepb.AttachmentStorageType_S3 {
if attachment.Payload == nil {
return nil, errors.New("attachment payload is missing")
}
s3Object := attachment.Payload.GetS3Object()
if s3Object == nil {
return nil, errors.New("S3 object payload is missing")
}
if s3Object.S3Config == nil {
return nil, errors.New("S3 config is missing")
}
if s3Object.Key == "" {
return nil, errors.New("S3 object key is missing")
}
s3Client, err := s3.NewClient(context.Background(), s3Object.S3Config)
if err != nil {
return nil, errors.Wrap(err, "failed to create S3 client")
}
blob, err := s3Client.GetObject(context.Background(), s3Object.Key)
if err != nil {
return nil, errors.Wrap(err, "failed to get object from S3")
}
return blob, nil
}
// For database storage, return the blob from the database.
return attachment.Blob, nil
return s.Store.GetUser(ctx, &store.FindUser{Username: &identifier})
}
// getOrGenerateThumbnail returns the thumbnail image of the attachment.
// Uses semaphore to limit concurrent thumbnail generation and prevent memory exhaustion.
func (s *FileServerService) getOrGenerateThumbnail(ctx context.Context, attachment *store.Attachment) ([]byte, error) {
thumbnailCacheFolder := filepath.Join(s.Profile.Data, ThumbnailCacheFolder)
if err := os.MkdirAll(thumbnailCacheFolder, os.ModePerm); err != nil {
return nil, errors.Wrap(err, "failed to create thumbnail cache folder")
}
filePath := filepath.Join(thumbnailCacheFolder, fmt.Sprintf("%d%s", attachment.ID, filepath.Ext(attachment.Filename)))
// =============================================================================
// Helper Functions
// =============================================================================
// Check if thumbnail already exists
if _, err := os.Stat(filePath); err == nil {
// Thumbnail exists, read and return it
thumbnailFile, err := os.Open(filePath)
if err != nil {
return nil, errors.Wrap(err, "failed to open thumbnail file")
}
defer thumbnailFile.Close()
blob, err := io.ReadAll(thumbnailFile)
if err != nil {
return nil, errors.Wrap(err, "failed to read thumbnail file")
}
return blob, nil
} else if !os.IsNotExist(err) {
return nil, errors.Wrap(err, "failed to check thumbnail image stat")
// sanitizeContentType converts potentially dangerous MIME types to safe alternatives.
func (*FileServerService) sanitizeContentType(mimeType string) string {
contentType := mimeType
if strings.HasPrefix(contentType, "text/") {
contentType += "; charset=utf-8"
}
// Thumbnail doesn't exist, acquire semaphore to limit concurrent generation
if err := s.thumbnailSemaphore.Acquire(ctx, 1); err != nil {
return nil, errors.Wrap(err, "failed to acquire thumbnail generation semaphore")
// Normalize for case-insensitive lookup.
if xssUnsafeTypes[strings.ToLower(mimeType)] {
return "application/octet-stream"
}
return contentType
}
// parseDataURI extracts MIME type and decoded data from a data URI.
func (*FileServerService) parseDataURI(dataURI string) (string, []byte, error) {
matches := dataURIRegex.FindStringSubmatch(dataURI)
if len(matches) != 3 {
return "", nil, errors.New("invalid data URI format")
}
imageType := matches[1]
imageData, err := base64.StdEncoding.DecodeString(matches[2])
if err != nil {
return "", nil, errors.Wrap(err, "failed to decode base64 data")
}
return imageType, imageData, nil
}
// isMediaType checks if the MIME type is video or audio.
func isMediaType(mimeType string) bool {
return strings.HasPrefix(mimeType, "video/") || strings.HasPrefix(mimeType, "audio/")
}
// setSecurityHeaders sets common security headers for all responses.
func setSecurityHeaders(c echo.Context) {
h := c.Response().Header()
h.Set("X-Content-Type-Options", "nosniff")
h.Set("X-Frame-Options", "DENY")
h.Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline';")
}
// setMediaHeaders sets headers for media file responses.
func setMediaHeaders(c echo.Context, contentType, originalType string) {
h := c.Response().Header()
h.Set(echo.HeaderContentType, contentType)
h.Set(echo.HeaderCacheControl, cacheMaxAge)
// Support HDR/wide color gamut for images and videos.
if strings.HasPrefix(originalType, "image/") || strings.HasPrefix(originalType, "video/") {
h.Set("Color-Gamut", "srgb, p3, rec2020")
}
defer s.thumbnailSemaphore.Release(1)
// Double-check if thumbnail was created while waiting for semaphore
if _, err := os.Stat(filePath); err == nil {
thumbnailFile, err := os.Open(filePath)
if err != nil {
return nil, errors.Wrap(err, "failed to open thumbnail file")
}
defer thumbnailFile.Close()
blob, err := io.ReadAll(thumbnailFile)
if err != nil {
return nil, errors.Wrap(err, "failed to read thumbnail file")
}
return blob, nil
}
// Generate the thumbnail
reader, err := s.getAttachmentReader(attachment)
if err != nil {
return nil, errors.Wrap(err, "failed to get attachment reader")
}
defer reader.Close()
// Decode image - this is memory intensive
img, err := imaging.Decode(reader, imaging.AutoOrientation(true))
if err != nil {
return nil, errors.Wrap(err, "failed to decode thumbnail image")
}
// The largest dimension is set to thumbnailMaxSize and the smaller dimension is scaled proportionally.
// Small images are not enlarged.
width := img.Bounds().Dx()
height := img.Bounds().Dy()
var thumbnailWidth, thumbnailHeight int
// Only resize if the image is larger than thumbnailMaxSize
if max(width, height) > thumbnailMaxSize {
if width >= height {
// Landscape or square - constrain width, maintain aspect ratio for height
thumbnailWidth = thumbnailMaxSize
thumbnailHeight = 0
} else {
// Portrait - constrain height, maintain aspect ratio for width
thumbnailWidth = 0
thumbnailHeight = thumbnailMaxSize
}
} else {
// Keep original dimensions for small images
thumbnailWidth = width
thumbnailHeight = height
}
// Resize the image to the calculated dimensions.
thumbnailImage := imaging.Resize(img, thumbnailWidth, thumbnailHeight, imaging.Lanczos)
// Save thumbnail to disk
if err := imaging.Save(thumbnailImage, filePath); err != nil {
return nil, errors.Wrap(err, "failed to save thumbnail file")
}
// Read the saved thumbnail and return it
thumbnailFile, err := os.Open(filePath)
if err != nil {
return nil, errors.Wrap(err, "failed to open thumbnail file")
}
defer thumbnailFile.Close()
thumbnailBlob, err := io.ReadAll(thumbnailFile)
if err != nil {
return nil, errors.Wrap(err, "failed to read thumbnail file")
}
return thumbnailBlob, nil
}

View File

@ -52,7 +52,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
}
secret := "usememos"
if profile.Mode == "prod" {
if !profile.Demo {
secret = instanceBasicSetting.SecretKey
}
s.Secret = secret

24
store/cache/cache.go vendored
View File

@ -161,20 +161,20 @@ func (c *Cache) Delete(_ context.Context, key string) {
// Clear removes all values from the cache.
func (c *Cache) Clear(_ context.Context) {
if c.config.OnEviction != nil {
c.data.Range(func(key, value any) bool {
count := 0
c.data.Range(func(key, value any) bool {
if c.config.OnEviction != nil {
itm, ok := value.(item)
if !ok {
return true
if ok {
if keyStr, ok := key.(string); ok {
c.config.OnEviction(keyStr, itm.value)
}
}
if keyStr, ok := key.(string); ok {
c.config.OnEviction(keyStr, itm.value)
}
return true
})
}
c.data = sync.Map{}
}
c.data.Delete(key)
count++
return true
})
(&c.itemCount).Store(0)
}

View File

@ -31,7 +31,7 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s
}
args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}
stmt := "INSERT INTO `resource` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
stmt := "INSERT INTO `attachment` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return nil, err
@ -50,38 +50,38 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "`resource`.`id` = ?"), append(args, *v)
where, args = append(where, "`attachment`.`id` = ?"), append(args, *v)
}
if v := find.UID; v != nil {
where, args = append(where, "`resource`.`uid` = ?"), append(args, *v)
where, args = append(where, "`attachment`.`uid` = ?"), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "`resource`.`creator_id` = ?"), append(args, *v)
where, args = append(where, "`attachment`.`creator_id` = ?"), append(args, *v)
}
if v := find.Filename; v != nil {
where, args = append(where, "`resource`.`filename` = ?"), append(args, *v)
where, args = append(where, "`attachment`.`filename` = ?"), append(args, *v)
}
if v := find.FilenameSearch; v != nil {
where, args = append(where, "`resource`.`filename` LIKE ?"), append(args, "%"+*v+"%")
where, args = append(where, "`attachment`.`filename` LIKE ?"), append(args, "%"+*v+"%")
}
if v := find.MemoID; v != nil {
where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v)
where, args = append(where, "`attachment`.`memo_id` = ?"), append(args, *v)
}
if len(find.MemoIDList) > 0 {
placeholders := make([]string, 0, len(find.MemoIDList))
for range find.MemoIDList {
placeholders = append(placeholders, "?")
}
where = append(where, "`resource`.`memo_id` IN ("+strings.Join(placeholders, ",")+")")
where = append(where, "`attachment`.`memo_id` IN ("+strings.Join(placeholders, ",")+")")
for _, id := range find.MemoIDList {
args = append(args, id)
}
}
if find.HasRelatedMemo {
where = append(where, "`resource`.`memo_id` IS NOT NULL")
where = append(where, "`attachment`.`memo_id` IS NOT NULL")
}
if find.StorageType != nil {
where, args = append(where, "`resource`.`storage_type` = ?"), append(args, find.StorageType.String())
where, args = append(where, "`attachment`.`storage_type` = ?"), append(args, find.StorageType.String())
}
if len(find.Filters) > 0 {
@ -95,26 +95,26 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
}
fields := []string{
"`resource`.`id` AS `id`",
"`resource`.`uid` AS `uid`",
"`resource`.`filename` AS `filename`",
"`resource`.`type` AS `type`",
"`resource`.`size` AS `size`",
"`resource`.`creator_id` AS `creator_id`",
"UNIX_TIMESTAMP(`resource`.`created_ts`) AS `created_ts`",
"UNIX_TIMESTAMP(`resource`.`updated_ts`) AS `updated_ts`",
"`resource`.`memo_id` AS `memo_id`",
"`resource`.`storage_type` AS `storage_type`",
"`resource`.`reference` AS `reference`",
"`resource`.`payload` AS `payload`",
"`attachment`.`id` AS `id`",
"`attachment`.`uid` AS `uid`",
"`attachment`.`filename` AS `filename`",
"`attachment`.`type` AS `type`",
"`attachment`.`size` AS `size`",
"`attachment`.`creator_id` AS `creator_id`",
"UNIX_TIMESTAMP(`attachment`.`created_ts`) AS `created_ts`",
"UNIX_TIMESTAMP(`attachment`.`updated_ts`) AS `updated_ts`",
"`attachment`.`memo_id` AS `memo_id`",
"`attachment`.`storage_type` AS `storage_type`",
"`attachment`.`reference` AS `reference`",
"`attachment`.`payload` AS `payload`",
"CASE WHEN `memo`.`uid` IS NOT NULL THEN `memo`.`uid` ELSE NULL END AS `memo_uid`",
}
if find.GetBlob {
fields = append(fields, "`resource`.`blob` AS `blob`")
fields = append(fields, "`attachment`.`blob` AS `blob`")
}
query := "SELECT " + strings.Join(fields, ", ") + " FROM `resource`" + " " +
"LEFT JOIN `memo` ON `resource`.`memo_id` = `memo`.`id`" + " " +
query := "SELECT " + strings.Join(fields, ", ") + " FROM `attachment`" + " " +
"LEFT JOIN `memo` ON `attachment`.`memo_id` = `memo`.`id`" + " " +
"WHERE " + strings.Join(where, " AND ") + " " +
"ORDER BY `updated_ts` DESC"
if find.Limit != nil {
@ -216,7 +216,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen
}
args = append(args, update.ID)
stmt := "UPDATE `resource` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
stmt := "UPDATE `attachment` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
@ -228,7 +228,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen
}
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
stmt := "DELETE FROM `resource` WHERE `id` = ?"
stmt := "DELETE FROM `attachment` WHERE `id` = ?"
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
if err != nil {
return err

View File

@ -63,7 +63,11 @@ func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.I
if find.MessageType != nil {
// Filter by message type using JSON extraction
// Note: The type field in JSON is stored as string representation of the enum name
where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String())
if *find.MessageType == storepb.InboxMessage_TYPE_UNSPECIFIED {
where, args = append(where, "(JSON_EXTRACT(`message`, '$.type') IS NULL OR JSON_EXTRACT(`message`, '$.type') = ?)"), append(args, find.MessageType.String())
} else {
where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String())
}
}
query := "SELECT `id`, UNIX_TIMESTAMP(`created_ts`), `sender_id`, `receiver_id`, `status`, `message` FROM `inbox` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"

View File

@ -26,6 +26,18 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e
}
args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}
// Add custom timestamps if provided
if create.CreatedTs != 0 {
fields = append(fields, "`created_ts`")
placeholder = append(placeholder, "?")
args = append(args, create.CreatedTs)
}
if create.UpdatedTs != 0 {
fields = append(fields, "`updated_ts`")
placeholder = append(placeholder, "?")
args = append(args, create.UpdatedTs)
}
stmt := "INSERT INTO `memo` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {

View File

@ -1,162 +0,0 @@
package mysql
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `tag in ["tag1", "tag2"]`,
want: "((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))",
args: []any{`"tag1"`, `%"tag1/%`, `"tag2"`, `%"tag2/%`},
},
{
filter: `!(tag in ["tag1", "tag2"])`,
want: "NOT (((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?)))",
args: []any{`"tag1"`, `%"tag1/%`, `"tag2"`, `%"tag2/%`},
},
{
filter: `content.contains("memos")`,
want: "`memo`.`content` LIKE ?",
args: []any{"%memos%"},
},
{
filter: `visibility in ["PUBLIC"]`,
want: "`memo`.`visibility` IN (?)",
args: []any{"PUBLIC"},
},
{
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
want: "`memo`.`visibility` IN (?,?)",
args: []any{"PUBLIC", "PRIVATE"},
},
{
filter: `tag in ['tag1'] || content.contains('hello')`,
want: "((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR `memo`.`content` LIKE ?)",
args: []any{`"tag1"`, `%"tag1/%`, "%hello%"},
},
{
filter: `1`,
want: "",
args: []any{},
},
{
filter: `pinned`,
want: "`memo`.`pinned` IS TRUE",
args: []any{},
},
{
filter: `has_task_list`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
args: []any{},
},
{
filter: `has_task_list == true`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
args: []any{},
},
{
filter: `has_task_list != false`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)",
args: []any{},
},
{
filter: `has_task_list == false`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)",
args: []any{},
},
{
filter: `!has_task_list`,
want: "NOT (JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON))",
args: []any{},
},
{
filter: `has_task_list && pinned`,
want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON) AND `memo`.`pinned` IS TRUE)",
args: []any{},
},
{
filter: `has_task_list && content.contains("todo")`,
want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON) AND `memo`.`content` LIKE ?)",
args: []any{"%todo%"},
},
{
filter: `created_ts > now() - 60 * 60 * 24`,
want: "UNIX_TIMESTAMP(`memo`.`created_ts`) > ?",
args: []any{time.Now().Unix() - 60*60*24},
},
{
filter: `size(tags) == 0`,
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
args: []any{int64(0)},
},
{
filter: `size(tags) > 0`,
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?",
args: []any{int64(0)},
},
{
filter: `"work" in tags`,
want: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
args: []any{`"work"`},
},
{
filter: `size(tags) == 2`,
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
args: []any{int64(2)},
},
{
filter: `has_link == true`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)",
args: []any{},
},
{
filter: `has_code == false`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('false' AS JSON)",
args: []any{},
},
{
filter: `has_incomplete_tasks != false`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') != CAST('false' AS JSON)",
args: []any{},
},
{
filter: `has_link`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)",
args: []any{},
},
{
filter: `has_code`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('true' AS JSON)",
args: []any{},
},
{
filter: `has_incomplete_tasks`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') = CAST('true' AS JSON)",
args: []any{},
},
}
engine, err := filter.DefaultEngine()
require.NoError(t, err)
for _, tt := range tests {
stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{
Dialect: filter.DialectMySQL,
})
require.NoError(t, err)
require.Equal(t, tt.want, stmt.SQL)
require.Equal(t, tt.args, stmt.Args)
}
}

View File

@ -10,7 +10,7 @@ import (
)
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
stmt := "INSERT INTO `memo_relation` (`memo_id`, `related_memo_id`, `type`) VALUES (?, ?, ?)"
stmt := "INSERT INTO `memo_relation` (`memo_id`, `related_memo_id`, `type`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `type` = `type`"
_, err := d.db.ExecContext(
ctx,
stmt,

View File

@ -30,7 +30,7 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s
}
args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}
stmt := "INSERT INTO resource (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts"
stmt := "INSERT INTO attachment (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil {
return nil, err
}
@ -41,22 +41,22 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "resource.id = "+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "attachment.id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.UID; v != nil {
where, args = append(where, "resource.uid = "+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "attachment.uid = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "resource.creator_id = "+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "attachment.creator_id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Filename; v != nil {
where, args = append(where, "resource.filename = "+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "attachment.filename = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.FilenameSearch; v != nil {
where, args = append(where, "resource.filename LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", *v))
where, args = append(where, "attachment.filename LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", *v))
}
if v := find.MemoID; v != nil {
where, args = append(where, "resource.memo_id = "+placeholder(len(args)+1)), append(args, *v)
where, args = append(where, "attachment.memo_id = "+placeholder(len(args)+1)), append(args, *v)
}
if len(find.MemoIDList) > 0 {
holders := make([]string, 0, len(find.MemoIDList))
@ -64,13 +64,13 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
holders = append(holders, placeholder(len(args)+1))
args = append(args, id)
}
where = append(where, "resource.memo_id IN ("+strings.Join(holders, ", ")+")")
where = append(where, "attachment.memo_id IN ("+strings.Join(holders, ", ")+")")
}
if find.HasRelatedMemo {
where = append(where, "resource.memo_id IS NOT NULL")
where = append(where, "attachment.memo_id IS NOT NULL")
}
if v := find.StorageType; v != nil {
where, args = append(where, "resource.storage_type = "+placeholder(len(args)+1)), append(args, v.String())
where, args = append(where, "attachment.storage_type = "+placeholder(len(args)+1)), append(args, v.String())
}
if len(find.Filters) > 0 {
@ -84,31 +84,31 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
}
fields := []string{
"resource.id AS id",
"resource.uid AS uid",
"resource.filename AS filename",
"resource.type AS type",
"resource.size AS size",
"resource.creator_id AS creator_id",
"resource.created_ts AS created_ts",
"resource.updated_ts AS updated_ts",
"resource.memo_id AS memo_id",
"resource.storage_type AS storage_type",
"resource.reference AS reference",
"resource.payload AS payload",
"attachment.id AS id",
"attachment.uid AS uid",
"attachment.filename AS filename",
"attachment.type AS type",
"attachment.size AS size",
"attachment.creator_id AS creator_id",
"attachment.created_ts AS created_ts",
"attachment.updated_ts AS updated_ts",
"attachment.memo_id AS memo_id",
"attachment.storage_type AS storage_type",
"attachment.reference AS reference",
"attachment.payload AS payload",
"CASE WHEN memo.uid IS NOT NULL THEN memo.uid ELSE NULL END AS memo_uid",
}
if find.GetBlob {
fields = append(fields, "resource.blob AS blob")
fields = append(fields, "attachment.blob AS blob")
}
query := fmt.Sprintf(`
SELECT
%s
FROM resource
LEFT JOIN memo ON resource.memo_id = memo.id
FROM attachment
LEFT JOIN memo ON attachment.memo_id = memo.id
WHERE %s
ORDER BY resource.updated_ts DESC
ORDER BY attachment.updated_ts DESC
`, strings.Join(fields, ", "), strings.Join(where, " AND "))
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
@ -196,7 +196,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen
set, args = append(set, "payload = "+placeholder(len(args)+1)), append(args, string(bytes))
}
stmt := `UPDATE resource SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1)
stmt := `UPDATE attachment SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1)
args = append(args, update.ID)
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
@ -209,7 +209,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen
}
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
stmt := `DELETE FROM resource WHERE id = $1`
stmt := `DELETE FROM attachment WHERE id = $1`
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
if err != nil {
return err

View File

@ -53,7 +53,12 @@ func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.I
if find.MessageType != nil {
// Filter by message type using PostgreSQL JSON extraction
// Note: The type field in JSON is stored as string representation of the enum name
where, args = append(where, "message->>'type' = "+placeholder(len(args)+1)), append(args, find.MessageType.String())
// Cast to JSONB since the column is TEXT
if *find.MessageType == storepb.InboxMessage_TYPE_UNSPECIFIED {
where, args = append(where, "(message::JSONB->>'type' IS NULL OR message::JSONB->>'type' = "+placeholder(len(args)+1)+")"), append(args, find.MessageType.String())
} else {
where, args = append(where, "message::JSONB->>'type' = "+placeholder(len(args)+1)), append(args, find.MessageType.String())
}
}
query := "SELECT id, created_ts, sender_id, receiver_id, status, message FROM inbox WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"

View File

@ -25,6 +25,16 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e
}
args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}
// Add custom timestamps if provided
if create.CreatedTs != 0 {
fields = append(fields, "created_ts")
args = append(args, create.CreatedTs)
}
if create.UpdatedTs != 0 {
fields = append(fields, "updated_ts")
args = append(args, create.UpdatedTs)
}
stmt := "INSERT INTO memo (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts, row_status"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.ID,

View File

@ -1,160 +0,0 @@
package postgres
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `tag in ["tag1", "tag2"]`,
want: "((memo.payload->'tags' @> jsonb_build_array($1::json) OR (memo.payload->'tags')::text LIKE $2) OR (memo.payload->'tags' @> jsonb_build_array($3::json) OR (memo.payload->'tags')::text LIKE $4))",
args: []any{`"tag1"`, `%"tag1/%`, `"tag2"`, `%"tag2/%`},
},
{
filter: `!(tag in ["tag1", "tag2"])`,
want: "NOT (((memo.payload->'tags' @> jsonb_build_array($1::json) OR (memo.payload->'tags')::text LIKE $2) OR (memo.payload->'tags' @> jsonb_build_array($3::json) OR (memo.payload->'tags')::text LIKE $4)))",
args: []any{`"tag1"`, `%"tag1/%`, `"tag2"`, `%"tag2/%`},
},
{
filter: `content.contains("memos")`,
want: "memo.content ILIKE $1",
args: []any{"%memos%"},
},
{
filter: `visibility in ["PUBLIC"]`,
want: "memo.visibility IN ($1)",
args: []any{"PUBLIC"},
},
{
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
want: "memo.visibility IN ($1,$2)",
args: []any{"PUBLIC", "PRIVATE"},
},
{
filter: `tag in ['tag1'] || content.contains('hello')`,
want: "((memo.payload->'tags' @> jsonb_build_array($1::json) OR (memo.payload->'tags')::text LIKE $2) OR memo.content ILIKE $3)",
args: []any{`"tag1"`, `%"tag1/%`, "%hello%"},
},
{
filter: `1`,
want: "",
args: []any{},
},
{
filter: `pinned`,
want: "memo.pinned IS TRUE",
args: []any{},
},
{
filter: `has_task_list`,
want: "(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE",
args: []any{},
},
{
filter: `has_task_list == true`,
want: "(memo.payload->'property'->>'hasTaskList')::boolean = $1",
args: []any{true},
},
{
filter: `has_task_list != false`,
want: "(memo.payload->'property'->>'hasTaskList')::boolean != $1",
args: []any{false},
},
{
filter: `has_task_list == false`,
want: "(memo.payload->'property'->>'hasTaskList')::boolean = $1",
args: []any{false},
},
{
filter: `!has_task_list`,
want: "NOT ((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE)",
args: []any{},
},
{
filter: `has_task_list && pinned`,
want: "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.pinned IS TRUE)",
args: []any{},
},
{
filter: `has_task_list && content.contains("todo")`,
want: "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.content ILIKE $1)",
args: []any{"%todo%"},
},
{
filter: `created_ts > now() - 60 * 60 * 24`,
want: "memo.created_ts > $1",
args: []any{time.Now().Unix() - 60*60*24},
},
{
filter: `size(tags) == 0`,
want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1",
args: []any{int64(0)},
},
{
filter: `size(tags) > 0`,
want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) > $1",
args: []any{int64(0)},
},
{
filter: `"work" in tags`,
want: "memo.payload->'tags' @> jsonb_build_array($1::json)",
args: []any{`"work"`},
},
{
filter: `size(tags) == 2`,
want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1",
args: []any{int64(2)},
},
{
filter: `has_link == true`,
want: "(memo.payload->'property'->>'hasLink')::boolean = $1",
args: []any{true},
},
{
filter: `has_code == false`,
want: "(memo.payload->'property'->>'hasCode')::boolean = $1",
args: []any{false},
},
{
filter: `has_incomplete_tasks != false`,
want: "(memo.payload->'property'->>'hasIncompleteTasks')::boolean != $1",
args: []any{false},
},
{
filter: `has_link`,
want: "(memo.payload->'property'->>'hasLink')::boolean IS TRUE",
args: []any{},
},
{
filter: `has_code`,
want: "(memo.payload->'property'->>'hasCode')::boolean IS TRUE",
args: []any{},
},
{
filter: `has_incomplete_tasks`,
want: "(memo.payload->'property'->>'hasIncompleteTasks')::boolean IS TRUE",
args: []any{},
},
}
engine, err := filter.DefaultEngine()
require.NoError(t, err)
for _, tt := range tests {
stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{Dialect: filter.DialectPostgres})
require.NoError(t, err)
require.Equal(t, tt.want, stmt.SQL)
require.Equal(t, tt.args, stmt.Args)
}
}

View File

@ -17,6 +17,7 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation)
type
)
VALUES (` + placeholders(3) + `)
ON CONFLICT (memo_id, related_memo_id, type) DO UPDATE SET type = EXCLUDED.type
RETURNING memo_id, related_memo_id, type
`
memoRelation := &store.MemoRelation{}

View File

@ -31,7 +31,7 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s
}
args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}
stmt := "INSERT INTO `resource` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`, `updated_ts`"
stmt := "INSERT INTO `attachment` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`, `updated_ts`"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil {
return nil, err
}
@ -43,38 +43,38 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "`resource`.`id` = ?"), append(args, *v)
where, args = append(where, "`attachment`.`id` = ?"), append(args, *v)
}
if v := find.UID; v != nil {
where, args = append(where, "`resource`.`uid` = ?"), append(args, *v)
where, args = append(where, "`attachment`.`uid` = ?"), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "`resource`.`creator_id` = ?"), append(args, *v)
where, args = append(where, "`attachment`.`creator_id` = ?"), append(args, *v)
}
if v := find.Filename; v != nil {
where, args = append(where, "`resource`.`filename` = ?"), append(args, *v)
where, args = append(where, "`attachment`.`filename` = ?"), append(args, *v)
}
if v := find.FilenameSearch; v != nil {
where, args = append(where, "`resource`.`filename` LIKE ?"), append(args, fmt.Sprintf("%%%s%%", *v))
where, args = append(where, "`attachment`.`filename` LIKE ?"), append(args, fmt.Sprintf("%%%s%%", *v))
}
if v := find.MemoID; v != nil {
where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v)
where, args = append(where, "`attachment`.`memo_id` = ?"), append(args, *v)
}
if len(find.MemoIDList) > 0 {
placeholders := make([]string, 0, len(find.MemoIDList))
for range find.MemoIDList {
placeholders = append(placeholders, "?")
}
where = append(where, "`resource`.`memo_id` IN ("+strings.Join(placeholders, ",")+")")
where = append(where, "`attachment`.`memo_id` IN ("+strings.Join(placeholders, ",")+")")
for _, id := range find.MemoIDList {
args = append(args, id)
}
}
if find.HasRelatedMemo {
where = append(where, "`resource`.`memo_id` IS NOT NULL")
where = append(where, "`attachment`.`memo_id` IS NOT NULL")
}
if find.StorageType != nil {
where, args = append(where, "`resource`.`storage_type` = ?"), append(args, find.StorageType.String())
where, args = append(where, "`attachment`.`storage_type` = ?"), append(args, find.StorageType.String())
}
if len(find.Filters) > 0 {
@ -88,28 +88,28 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
}
fields := []string{
"`resource`.`id` AS `id`",
"`resource`.`uid` AS `uid`",
"`resource`.`filename` AS `filename`",
"`resource`.`type` AS `type`",
"`resource`.`size` AS `size`",
"`resource`.`creator_id` AS `creator_id`",
"`resource`.`created_ts` AS `created_ts`",
"`resource`.`updated_ts` AS `updated_ts`",
"`resource`.`memo_id` AS `memo_id`",
"`resource`.`storage_type` AS `storage_type`",
"`resource`.`reference` AS `reference`",
"`resource`.`payload` AS `payload`",
"`attachment`.`id` AS `id`",
"`attachment`.`uid` AS `uid`",
"`attachment`.`filename` AS `filename`",
"`attachment`.`type` AS `type`",
"`attachment`.`size` AS `size`",
"`attachment`.`creator_id` AS `creator_id`",
"`attachment`.`created_ts` AS `created_ts`",
"`attachment`.`updated_ts` AS `updated_ts`",
"`attachment`.`memo_id` AS `memo_id`",
"`attachment`.`storage_type` AS `storage_type`",
"`attachment`.`reference` AS `reference`",
"`attachment`.`payload` AS `payload`",
"CASE WHEN `memo`.`uid` IS NOT NULL THEN `memo`.`uid` ELSE NULL END AS `memo_uid`",
}
if find.GetBlob {
fields = append(fields, "`resource`.`blob` AS `blob`")
fields = append(fields, "`attachment`.`blob` AS `blob`")
}
query := "SELECT " + strings.Join(fields, ", ") + " FROM `resource`" + " " +
"LEFT JOIN `memo` ON `resource`.`memo_id` = `memo`.`id`" + " " +
query := "SELECT " + strings.Join(fields, ", ") + " FROM `attachment`" + " " +
"LEFT JOIN `memo` ON `attachment`.`memo_id` = `memo`.`id`" + " " +
"WHERE " + strings.Join(where, " AND ") + " " +
"ORDER BY `resource`.`updated_ts` DESC"
"ORDER BY `attachment`.`updated_ts` DESC"
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
@ -197,7 +197,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen
}
args = append(args, update.ID)
stmt := "UPDATE `resource` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
stmt := "UPDATE `attachment` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return errors.Wrap(err, "failed to update attachment")
@ -209,7 +209,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen
}
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
stmt := "DELETE FROM `resource` WHERE `id` = ?"
stmt := "DELETE FROM `attachment` WHERE `id` = ?"
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
if err != nil {
return err

View File

@ -55,7 +55,11 @@ func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.I
if find.MessageType != nil {
// Filter by message type using JSON extraction
// Note: The type field in JSON is stored as string representation of the enum name
where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String())
if *find.MessageType == storepb.InboxMessage_TYPE_UNSPECIFIED {
where, args = append(where, "(JSON_EXTRACT(`message`, '$.type') IS NULL OR JSON_EXTRACT(`message`, '$.type') = ?)"), append(args, find.MessageType.String())
} else {
where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String())
}
}
query := "SELECT `id`, `created_ts`, `sender_id`, `receiver_id`, `status`, `message` FROM `inbox` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"

View File

@ -26,6 +26,18 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e
}
args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}
// Add custom timestamps if provided
if create.CreatedTs != 0 {
fields = append(fields, "`created_ts`")
placeholder = append(placeholder, "?")
args = append(args, create.CreatedTs)
}
if create.UpdatedTs != 0 {
fields = append(fields, "`updated_ts`")
placeholder = append(placeholder, "?")
args = append(args, create.UpdatedTs)
}
stmt := "INSERT INTO `memo` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`, `updated_ts`, `row_status`"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.ID,

View File

@ -1,165 +0,0 @@
package sqlite
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `tag in ["tag1", "tag2"]`,
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))",
args: []any{`%"tag1"%`, `%"tag1/%`, `%"tag2"%`, `%"tag2/%`},
},
{
filter: `!(tag in ["tag1", "tag2"])`,
want: "NOT (((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?)))",
args: []any{`%"tag1"%`, `%"tag1/%`, `%"tag2"%`, `%"tag2/%`},
},
{
filter: `content.contains("memos")`,
want: "`memo`.`content` LIKE ?",
args: []any{"%memos%"},
},
{
filter: `visibility in ["PUBLIC"]`,
want: "`memo`.`visibility` IN (?)",
args: []any{"PUBLIC"},
},
{
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
want: "`memo`.`visibility` IN (?,?)",
args: []any{"PUBLIC", "PRIVATE"},
},
{
filter: `tag in ['tag1'] || content.contains('hello')`,
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR `memo`.`content` LIKE ?)",
args: []any{`%"tag1"%`, `%"tag1/%`, "%hello%"},
},
{
filter: `1`,
want: "",
args: []any{},
},
{
filter: `pinned`,
want: "`memo`.`pinned` IS TRUE",
args: []any{},
},
{
filter: `has_task_list`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE",
args: []any{},
},
{
filter: `has_code`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE",
args: []any{},
},
{
filter: `has_task_list == true`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 1",
args: []any{},
},
{
filter: `has_task_list != false`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 0",
args: []any{},
},
{
filter: `has_task_list == false`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 0",
args: []any{},
},
{
filter: `!has_task_list`,
want: "NOT (JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE)",
args: []any{},
},
{
filter: `has_task_list && pinned`,
want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE AND `memo`.`pinned` IS TRUE)",
args: []any{},
},
{
filter: `has_task_list && content.contains("todo")`,
want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE AND `memo`.`content` LIKE ?)",
args: []any{"%todo%"},
},
{
filter: `created_ts > now() - 60 * 60 * 24`,
want: "`memo`.`created_ts` > ?",
args: []any{time.Now().Unix() - 60*60*24},
},
{
filter: `size(tags) == 0`,
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
args: []any{int64(0)},
},
{
filter: `size(tags) > 0`,
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?",
args: []any{int64(0)},
},
{
filter: `"work" in tags`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
args: []any{`%"work"%`},
},
{
filter: `size(tags) == 2`,
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
args: []any{int64(2)},
},
{
filter: `has_link == true`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') IS TRUE",
args: []any{},
},
{
filter: `has_code == false`,
want: "NOT(JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE)",
args: []any{},
},
{
filter: `has_incomplete_tasks != false`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE",
args: []any{},
},
{
filter: `has_link`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') IS TRUE",
args: []any{},
},
{
filter: `has_code`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE",
args: []any{},
},
{
filter: `has_incomplete_tasks`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE",
args: []any{},
},
}
engine, err := filter.DefaultEngine()
require.NoError(t, err)
for _, tt := range tests {
stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{Dialect: filter.DialectSQLite})
require.NoError(t, err)
require.Equal(t, tt.want, stmt.SQL)
require.Equal(t, tt.args, stmt.Args)
}
}

View File

@ -17,6 +17,7 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation)
type
)
VALUES (?, ?, ?)
ON CONFLICT(memo_id, related_memo_id, type) DO UPDATE SET type = excluded.type
RETURNING memo_id, related_memo_id, type
`
memoRelation := &store.MemoRelation{}

View File

@ -33,6 +33,7 @@ func NewDB(profile *profile.Profile) (store.Driver, error) {
// good practice to be explicit and prevent future surprises on SQLite upgrades.
// - Journal mode set to WAL: it's the recommended journal mode for most applications
// as it prevents locking issues.
// - mmap size set to 0: it disables memory mapping, which can cause OOM errors on some systems.
//
// Notes:
// - When using the `modernc.org/sqlite` driver, each pragma must be prefixed with `_pragma=`.
@ -41,7 +42,7 @@ func NewDB(profile *profile.Profile) (store.Driver, error) {
// - https://pkg.go.dev/modernc.org/sqlite#Driver.Open
// - https://www.sqlite.org/sharedcache.html
// - https://www.sqlite.org/pragma.html
sqliteDB, err := sql.Open("sqlite", profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)")
sqliteDB, err := sql.Open("sqlite", profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)&_pragma=mmap_size(0)")
if err != nil {
return nil, errors.Wrapf(err, "failed to open db with dsn: %s", profile.DSN)
}

View File

@ -138,5 +138,22 @@ func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) error {
}
func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error {
// Clean up memo_relation records where this memo is either the source or target.
if err := s.driver.DeleteMemoRelation(ctx, &DeleteMemoRelation{MemoID: &delete.ID}); err != nil {
return err
}
if err := s.driver.DeleteMemoRelation(ctx, &DeleteMemoRelation{RelatedMemoID: &delete.ID}); err != nil {
return err
}
// Clean up attachments linked to this memo.
attachments, err := s.ListAttachments(ctx, &FindAttachment{MemoID: &delete.ID})
if err != nil {
return err
}
for _, attachment := range attachments {
if err := s.DeleteAttachment(ctx, &DeleteAttachment{ID: attachment.ID}); err != nil {
return err
}
}
return s.driver.DeleteMemo(ctx, delete)
}

View File

@ -0,0 +1 @@
RENAME TABLE resource TO attachment;

View File

@ -0,0 +1 @@
DROP TABLE IF EXISTS memo_organizer;

View File

@ -0,0 +1 @@
UPDATE `user` SET `role` = 'ADMIN' WHERE `role` = 'HOST';

View File

@ -42,14 +42,6 @@ CREATE TABLE `memo` (
`payload` JSON NOT NULL
);
-- memo_organizer
CREATE TABLE `memo_organizer` (
`memo_id` INT NOT NULL,
`user_id` INT NOT NULL,
`pinned` INT NOT NULL DEFAULT '0',
UNIQUE(`memo_id`,`user_id`)
);
-- memo_relation
CREATE TABLE `memo_relation` (
`memo_id` INT NOT NULL,
@ -58,8 +50,8 @@ CREATE TABLE `memo_relation` (
UNIQUE(`memo_id`,`related_memo_id`,`type`)
);
-- resource
CREATE TABLE `resource` (
-- attachment
CREATE TABLE `attachment` (
`id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
`uid` VARCHAR(256) NOT NULL UNIQUE,
`creator_id` INT NOT NULL,

View File

@ -0,0 +1 @@
ALTER TABLE resource RENAME TO attachment;

View File

@ -0,0 +1 @@
DROP TABLE IF EXISTS memo_organizer;

View File

@ -0,0 +1 @@
UPDATE "user" SET role = 'ADMIN' WHERE role = 'HOST';

View File

@ -42,14 +42,6 @@ CREATE TABLE memo (
payload JSONB NOT NULL DEFAULT '{}'
);
-- memo_organizer
CREATE TABLE memo_organizer (
memo_id INTEGER NOT NULL,
user_id INTEGER NOT NULL,
pinned INTEGER NOT NULL DEFAULT 0,
UNIQUE(memo_id, user_id)
);
-- memo_relation
CREATE TABLE memo_relation (
memo_id INTEGER NOT NULL,
@ -58,8 +50,8 @@ CREATE TABLE memo_relation (
UNIQUE(memo_id, related_memo_id, type)
);
-- resource
CREATE TABLE resource (
-- attachment
CREATE TABLE attachment (
id SERIAL PRIMARY KEY,
uid TEXT NOT NULL UNIQUE,
creator_id INTEGER NOT NULL,

View File

@ -0,0 +1,5 @@
ALTER TABLE `resource` RENAME TO `attachment`;
DROP INDEX IF EXISTS `idx_resource_creator_id`;
CREATE INDEX `idx_attachment_creator_id` ON `attachment` (`creator_id`);
DROP INDEX IF EXISTS `idx_resource_memo_id`;
CREATE INDEX `idx_attachment_memo_id` ON `attachment` (`memo_id`);

View File

@ -0,0 +1 @@
DROP TABLE IF EXISTS memo_organizer;

View File

@ -0,0 +1,4 @@
DROP INDEX IF EXISTS idx_user_username;
DROP INDEX IF EXISTS idx_memo_creator_id;
DROP INDEX IF EXISTS idx_attachment_creator_id;
DROP INDEX IF EXISTS idx_attachment_memo_id;

View File

@ -0,0 +1,24 @@
ALTER TABLE user RENAME TO user_old;
CREATE TABLE user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
username TEXT NOT NULL UNIQUE,
role TEXT NOT NULL DEFAULT 'USER',
email TEXT NOT NULL DEFAULT '',
nickname TEXT NOT NULL DEFAULT '',
password_hash TEXT NOT NULL,
avatar_url TEXT NOT NULL DEFAULT '',
description TEXT NOT NULL DEFAULT ''
);
INSERT INTO user (
id, created_ts, updated_ts, row_status, username, role, email, nickname, password_hash, avatar_url, description
)
SELECT
id, created_ts, updated_ts, row_status, username, role, email, nickname, password_hash, avatar_url, description
FROM user_old;
DROP TABLE user_old;

View File

@ -0,0 +1 @@
UPDATE user SET role = 'ADMIN' WHERE role = 'HOST';

View File

@ -13,7 +13,7 @@ CREATE TABLE user (
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
username TEXT NOT NULL UNIQUE,
role TEXT NOT NULL CHECK (role IN ('HOST', 'ADMIN', 'USER')) DEFAULT 'USER',
role TEXT NOT NULL DEFAULT 'USER',
email TEXT NOT NULL DEFAULT '',
nickname TEXT NOT NULL DEFAULT '',
password_hash TEXT NOT NULL,
@ -21,8 +21,6 @@ CREATE TABLE user (
description TEXT NOT NULL DEFAULT ''
);
CREATE INDEX idx_user_username ON user (username);
-- user_setting
CREATE TABLE user_setting (
user_id INTEGER NOT NULL,
@ -45,16 +43,6 @@ CREATE TABLE memo (
payload TEXT NOT NULL DEFAULT '{}'
);
CREATE INDEX idx_memo_creator_id ON memo (creator_id);
-- memo_organizer
CREATE TABLE memo_organizer (
memo_id INTEGER NOT NULL,
user_id INTEGER NOT NULL,
pinned INTEGER NOT NULL CHECK (pinned IN (0, 1)) DEFAULT 0,
UNIQUE(memo_id, user_id)
);
-- memo_relation
CREATE TABLE memo_relation (
memo_id INTEGER NOT NULL,
@ -63,8 +51,8 @@ CREATE TABLE memo_relation (
UNIQUE(memo_id, related_memo_id, type)
);
-- resource
CREATE TABLE resource (
-- attachment
CREATE TABLE attachment (
id INTEGER PRIMARY KEY AUTOINCREMENT,
uid TEXT NOT NULL UNIQUE,
creator_id INTEGER NOT NULL,
@ -80,10 +68,6 @@ CREATE TABLE resource (
payload TEXT NOT NULL DEFAULT '{}'
);
CREATE INDEX idx_resource_creator_id ON resource (creator_id);
CREATE INDEX idx_resource_memo_id ON resource (memo_id);
-- activity
CREATE TABLE activity (
id INTEGER PRIMARY KEY AUTOINCREMENT,

View File

@ -21,7 +21,7 @@ import (
// Migration System Overview:
//
// The migration system handles database schema versioning and upgrades.
// Schema version is stored in instance_setting (formerly system_setting).
// Schema version is stored in system_setting.
//
// Migration Flow:
// 1. preMigrate: Check if DB is initialized. If not, apply LATEST.sql
@ -30,9 +30,9 @@ import (
// 4. Migrate (demo mode): Seed database with demo data
//
// Version Tracking:
// - New installations: Schema version set in instance_setting immediately
// - Existing v0.22+ installations: Schema version tracked in instance_setting
// - Pre-v0.22 installations: Must upgrade to v0.25.x first (migration_history → instance_setting migration)
// - New installations: Schema version set in system_setting immediately
// - Existing v0.22+ installations: Schema version tracked in system_setting
// - Pre-v0.22 installations: Must upgrade to v0.25.x first (migration_history → system_setting migration)
//
// Migration Files:
// - Location: store/migration/{driver}/{version}/NN__description.sql
@ -57,10 +57,6 @@ const (
// defaultSchemaVersion is used when schema version is empty or not set.
// This handles edge cases for old installations without version tracking.
defaultSchemaVersion = "0.0.0"
// Mode constants for profile mode.
modeProd = "prod"
modeDemo = "demo"
)
// getSchemaVersionOrDefault returns the schema version or default if empty.
@ -110,38 +106,36 @@ func (s *Store) Migrate(ctx context.Context) error {
return errors.Wrap(err, "failed to pre-migrate")
}
switch s.profile.Mode {
case modeProd:
instanceBasicSetting, err := s.GetInstanceBasicSetting(ctx)
if err != nil {
return errors.Wrap(err, "failed to get instance basic setting")
instanceBasicSetting, err := s.GetInstanceBasicSetting(ctx)
if err != nil {
return errors.Wrap(err, "failed to get instance basic setting")
}
currentSchemaVersion, err := s.GetCurrentSchemaVersion()
if err != nil {
return errors.Wrap(err, "failed to get current schema version")
}
// Check for downgrade (but skip if schema version is empty - that means fresh/old installation)
if !isVersionEmpty(instanceBasicSetting.SchemaVersion) && version.IsVersionGreaterThan(instanceBasicSetting.SchemaVersion, currentSchemaVersion) {
slog.Error("cannot downgrade schema version",
slog.String("databaseVersion", instanceBasicSetting.SchemaVersion),
slog.String("currentVersion", currentSchemaVersion),
)
return errors.Errorf("cannot downgrade schema version from %s to %s", instanceBasicSetting.SchemaVersion, currentSchemaVersion)
}
// Apply migrations if needed (including when schema version is empty)
if isVersionEmpty(instanceBasicSetting.SchemaVersion) || version.IsVersionGreaterThan(currentSchemaVersion, instanceBasicSetting.SchemaVersion) {
if err := s.applyMigrations(ctx, instanceBasicSetting.SchemaVersion, currentSchemaVersion); err != nil {
return errors.Wrap(err, "failed to apply migrations")
}
currentSchemaVersion, err := s.GetCurrentSchemaVersion()
if err != nil {
return errors.Wrap(err, "failed to get current schema version")
}
// Check for downgrade (but skip if schema version is empty - that means fresh/old installation)
if !isVersionEmpty(instanceBasicSetting.SchemaVersion) && version.IsVersionGreaterThan(instanceBasicSetting.SchemaVersion, currentSchemaVersion) {
slog.Error("cannot downgrade schema version",
slog.String("databaseVersion", instanceBasicSetting.SchemaVersion),
slog.String("currentVersion", currentSchemaVersion),
)
return errors.Errorf("cannot downgrade schema version from %s to %s", instanceBasicSetting.SchemaVersion, currentSchemaVersion)
}
// Apply migrations if needed (including when schema version is empty)
if isVersionEmpty(instanceBasicSetting.SchemaVersion) || version.IsVersionGreaterThan(currentSchemaVersion, instanceBasicSetting.SchemaVersion) {
if err := s.applyMigrations(ctx, instanceBasicSetting.SchemaVersion, currentSchemaVersion); err != nil {
return errors.Wrap(err, "failed to apply migrations")
}
}
case modeDemo:
}
if s.profile.Demo {
// In demo mode, we should seed the database.
if err := s.seed(ctx); err != nil {
return errors.Wrap(err, "failed to seed")
}
default:
// For other modes (like dev), no special migration handling needed
}
return nil
}
@ -255,14 +249,11 @@ func (s *Store) preMigrate(ctx context.Context) error {
}
}
if s.profile.Mode == modeProd {
if err := s.checkMinimumUpgradeVersion(ctx); err != nil {
return err // Error message is already descriptive, don't wrap it
}
if err := s.checkMinimumUpgradeVersion(ctx); err != nil {
return err // Error message is already descriptive, don't wrap it
}
return nil
}
func (s *Store) getMigrationBasePath() string {
return fmt.Sprintf("migration/%s/", s.profile.Driver)
}
@ -308,7 +299,7 @@ func (s *Store) seed(ctx context.Context) error {
}
func (s *Store) GetCurrentSchemaVersion() (string, error) {
currentVersion := version.GetCurrentVersion(s.profile.Mode)
currentVersion := version.GetCurrentVersion()
minorVersion := version.GetMinorVersion(currentVersion)
filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s%s/*.sql", s.getMigrationBasePath(), minorVersion))
if err != nil {
@ -373,7 +364,7 @@ func (s *Store) updateCurrentSchemaVersion(ctx context.Context, schemaVersion st
// checkMinimumUpgradeVersion verifies the installation meets minimum version requirements for upgrade.
// For very old installations (< v0.22.0), users must upgrade to v0.25.x first before upgrading to current version.
// This is necessary because schema version tracking was moved from migration_history to instance_setting in v0.22.0.
// This is necessary because schema version tracking was moved from migration_history to system_setting in v0.22.0.
func (s *Store) checkMinimumUpgradeVersion(ctx context.Context) error {
instanceBasicSetting, err := s.GetInstanceBasicSetting(ctx)
if err != nil {
@ -401,7 +392,7 @@ func (s *Store) checkMinimumUpgradeVersion(ctx context.Context) error {
"2. Start the server and verify it works\n"+
"3. Then upgrade to the latest version\n\n"+
"This is required because schema version tracking was moved from migration_history\n"+
"to instance_setting in v0.22.0. The intermediate upgrade handles this migration safely.",
"to system_setting in v0.22.0. The intermediate upgrade handles this migration safely.",
schemaVersion,
currentVersion,
)

View File

@ -10,7 +10,7 @@ The demo data includes **6 carefully selected memos** that showcase the key feat
- **Username**: `demo`
- **Password**: `secret` (default password)
- **Role**: HOST
- **Role**: ADMIN
- **Nickname**: Demo User
## Demo Memos (6 total)
@ -174,10 +174,10 @@ To run with demo data:
```bash
# Start in demo mode
go run ./cmd/memos --mode demo --port 8081
go run ./cmd/memos --demo --port 8081
# Or use the binary
./memos --mode demo
./memos --demo
# Demo database location
./build/memos_demo.db
@ -198,7 +198,7 @@ Login with:
- All memos are set to PUBLIC visibility
- **Two memos are pinned**: Welcome (#1) and Sponsor (#6)
- User has HOST role to showcase all features
- User has ADMIN role to showcase all features
- Reactions are distributed across memos
- One memo relation demonstrates linking
- Content is optimized for the compact markdown styles

View File

@ -1,11 +0,0 @@
DELETE FROM system_setting;
DELETE FROM user;
DELETE FROM user_setting;
DELETE FROM memo;
DELETE FROM memo_organizer;
DELETE FROM memo_relation;
DELETE FROM resource;
DELETE FROM activity;
DELETE FROM idp;
DELETE FROM inbox;
DELETE FROM reaction;

View File

@ -1,5 +1,5 @@
-- Demo User
INSERT INTO user (id,username,role,nickname,password_hash) VALUES(1,'demo','HOST','Demo User','$2a$10$c.slEVgf5b/3BnAWlLb/vOu7VVSOKJ4ljwMe9xzlx9IhKnvAsJYM6');
INSERT INTO user (id,username,role,nickname,password_hash) VALUES(1,'demo','ADMIN','Demo User','$2a$10$c.slEVgf5b/3BnAWlLb/vOu7VVSOKJ4ljwMe9xzlx9IhKnvAsJYM6');
-- Welcome Memo (Pinned)
INSERT INTO memo (id,uid,creator_id,content,visibility,pinned,payload) VALUES(1,'welcome2memos001',1,replace('# Welcome to Memos!\\n\\nA privacy-first, lightweight note-taking service. Easily capture and share your great thoughts.\\n\\n## Key Features\\n\\n- **Privacy First**: Your data stays with you\\n- **Markdown Support**: Full CommonMark + GFM syntax\\n- **Quick Capture**: Jot down thoughts instantly\\n- **Organize with Tags**: Use #tags to categorize\\n- **Open Source**: Free and open source software\\n\\n---\\n\\nStart exploring the demo memos below to see what you can do! #welcome #getting-started','\\n',char(10)),'PUBLIC',1,'{"tags":["welcome","getting-started"],"property":{"hasLink":false}}');

View File

@ -11,6 +11,7 @@ import (
)
func TestActivityStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -34,6 +35,7 @@ func TestActivityStore(t *testing.T) {
}
func TestActivityGetByID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -63,6 +65,7 @@ func TestActivityGetByID(t *testing.T) {
}
func TestActivityListMultiple(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -99,3 +102,279 @@ func TestActivityListMultiple(t *testing.T) {
ts.Close()
}
func TestActivityListByType(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create activities with MEMO_COMMENT type
_, err = ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
_, err = ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
// List by type
activityType := store.ActivityTypeMemoComment
activities, err := ts.ListActivities(ctx, &store.FindActivity{Type: &activityType})
require.NoError(t, err)
require.Len(t, activities, 2)
for _, activity := range activities {
require.Equal(t, store.ActivityTypeMemoComment, activity.Type)
}
ts.Close()
}
func TestActivityPayloadMemoComment(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create activity with MemoComment payload
memoID := int32(123)
relatedMemoID := int32(456)
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{
MemoComment: &storepb.ActivityMemoCommentPayload{
MemoId: memoID,
RelatedMemoId: relatedMemoID,
},
},
})
require.NoError(t, err)
require.NotNil(t, activity.Payload)
require.NotNil(t, activity.Payload.MemoComment)
require.Equal(t, memoID, activity.Payload.MemoComment.MemoId)
require.Equal(t, relatedMemoID, activity.Payload.MemoComment.RelatedMemoId)
// Verify payload is preserved when listing
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
require.NoError(t, err)
require.NotNil(t, found.Payload.MemoComment)
require.Equal(t, memoID, found.Payload.MemoComment.MemoId)
require.Equal(t, relatedMemoID, found.Payload.MemoComment.RelatedMemoId)
ts.Close()
}
func TestActivityEmptyPayload(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create activity with empty payload
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
require.NotNil(t, activity.Payload)
// Verify empty payload is handled correctly
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
require.NoError(t, err)
require.NotNil(t, found.Payload)
require.Nil(t, found.Payload.MemoComment)
ts.Close()
}
func TestActivityLevel(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create activity with INFO level
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
require.Equal(t, store.ActivityLevelInfo, activity.Level)
// Verify level is preserved when listing
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
require.NoError(t, err)
require.Equal(t, store.ActivityLevelInfo, found.Level)
ts.Close()
}
func TestActivityCreatorID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user1, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
require.NoError(t, err)
// Create activity for user1
activity1, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user1.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
require.Equal(t, user1.ID, activity1.CreatorID)
// Create activity for user2
activity2, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user2.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
require.Equal(t, user2.ID, activity2.CreatorID)
// List all and verify creator IDs
activities, err := ts.ListActivities(ctx, &store.FindActivity{})
require.NoError(t, err)
require.Len(t, activities, 2)
ts.Close()
}
func TestActivityCreatedTs(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
require.NotZero(t, activity.CreatedTs)
// Verify timestamp is preserved when listing
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
require.NoError(t, err)
require.Equal(t, activity.CreatedTs, found.CreatedTs)
ts.Close()
}
func TestActivityListEmpty(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// List activities when none exist
activities, err := ts.ListActivities(ctx, &store.FindActivity{})
require.NoError(t, err)
require.Len(t, activities, 0)
ts.Close()
}
func TestActivityListWithIDAndType(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{},
})
require.NoError(t, err)
// List with both ID and Type filters
activityType := store.ActivityTypeMemoComment
activities, err := ts.ListActivities(ctx, &store.FindActivity{
ID: &activity.ID,
Type: &activityType,
})
require.NoError(t, err)
require.Len(t, activities, 1)
require.Equal(t, activity.ID, activities[0].ID)
ts.Close()
}
func TestActivityPayloadComplexMemoComment(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create a memo first to use its ID
memo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "test-memo-for-activity",
CreatorID: user.ID,
Content: "Test memo content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create comment memo
commentMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "comment-memo",
CreatorID: user.ID,
Content: "This is a comment",
Visibility: store.Public,
})
require.NoError(t, err)
// Create activity with real memo IDs
activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: user.ID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{
MemoComment: &storepb.ActivityMemoCommentPayload{
MemoId: memo.ID,
RelatedMemoId: commentMemo.ID,
},
},
})
require.NoError(t, err)
require.Equal(t, memo.ID, activity.Payload.MemoComment.MemoId)
require.Equal(t, commentMemo.ID, activity.Payload.MemoComment.RelatedMemoId)
// Verify payload is preserved
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
require.NoError(t, err)
require.Equal(t, memo.ID, found.Payload.MemoComment.MemoId)
require.Equal(t, commentMemo.ID, found.Payload.MemoComment.RelatedMemoId)
ts.Close()
}

View File

@ -13,6 +13,7 @@ import (
// =============================================================================
func TestAttachmentFilterFilenameContains(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -35,6 +36,7 @@ func TestAttachmentFilterFilenameContains(t *testing.T) {
}
func TestAttachmentFilterFilenameSpecialCharacters(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -51,6 +53,7 @@ func TestAttachmentFilterFilenameSpecialCharacters(t *testing.T) {
}
func TestAttachmentFilterFilenameUnicode(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -67,6 +70,7 @@ func TestAttachmentFilterFilenameUnicode(t *testing.T) {
// =============================================================================
func TestAttachmentFilterMimeTypeEquals(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -86,6 +90,7 @@ func TestAttachmentFilterMimeTypeEquals(t *testing.T) {
}
func TestAttachmentFilterMimeTypeNotEquals(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -98,6 +103,7 @@ func TestAttachmentFilterMimeTypeNotEquals(t *testing.T) {
}
func TestAttachmentFilterMimeTypeInList(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -121,6 +127,7 @@ func TestAttachmentFilterMimeTypeInList(t *testing.T) {
// =============================================================================
func TestAttachmentFilterCreateTimeComparison(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -141,6 +148,7 @@ func TestAttachmentFilterCreateTimeComparison(t *testing.T) {
}
func TestAttachmentFilterCreateTimeWithNow(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -156,6 +164,7 @@ func TestAttachmentFilterCreateTimeWithNow(t *testing.T) {
}
func TestAttachmentFilterCreateTimeArithmetic(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -175,6 +184,7 @@ func TestAttachmentFilterCreateTimeArithmetic(t *testing.T) {
}
func TestAttachmentFilterAllComparisonOperators(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -203,6 +213,7 @@ func TestAttachmentFilterAllComparisonOperators(t *testing.T) {
// =============================================================================
func TestAttachmentFilterMemoIdEquals(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContextWithUser(t)
defer tc.Close()
@ -218,6 +229,7 @@ func TestAttachmentFilterMemoIdEquals(t *testing.T) {
}
func TestAttachmentFilterMemoIdNotEquals(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContextWithUser(t)
defer tc.Close()
@ -238,6 +250,7 @@ func TestAttachmentFilterMemoIdNotEquals(t *testing.T) {
// =============================================================================
func TestAttachmentFilterLogicalAnd(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -251,6 +264,7 @@ func TestAttachmentFilterLogicalAnd(t *testing.T) {
}
func TestAttachmentFilterLogicalOr(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -263,6 +277,7 @@ func TestAttachmentFilterLogicalOr(t *testing.T) {
}
func TestAttachmentFilterLogicalNot(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -275,6 +290,7 @@ func TestAttachmentFilterLogicalNot(t *testing.T) {
}
func TestAttachmentFilterComplexLogical(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -291,6 +307,7 @@ func TestAttachmentFilterComplexLogical(t *testing.T) {
// =============================================================================
func TestAttachmentFilterMultipleFilters(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -310,6 +327,7 @@ func TestAttachmentFilterMultipleFilters(t *testing.T) {
// =============================================================================
func TestAttachmentFilterNoMatches(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()
@ -320,6 +338,7 @@ func TestAttachmentFilterNoMatches(t *testing.T) {
}
func TestAttachmentFilterNullMemoId(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContextWithUser(t)
defer tc.Close()
@ -343,6 +362,7 @@ func TestAttachmentFilterNullMemoId(t *testing.T) {
}
func TestAttachmentFilterEmptyFilename(t *testing.T) {
t.Parallel()
tc := NewAttachmentFilterTestContext(t)
defer tc.Close()

View File

@ -12,6 +12,7 @@ import (
)
func TestAttachmentStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
_, err := ts.CreateAttachment(ctx, &store.Attachment{
@ -64,6 +65,7 @@ func TestAttachmentStore(t *testing.T) {
}
func TestAttachmentStoreWithFilter(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -123,6 +125,7 @@ func TestAttachmentStoreWithFilter(t *testing.T) {
}
func TestAttachmentUpdate(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -153,6 +156,7 @@ func TestAttachmentUpdate(t *testing.T) {
}
func TestAttachmentGetByUID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -183,6 +187,7 @@ func TestAttachmentGetByUID(t *testing.T) {
}
func TestAttachmentListWithPagination(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -222,6 +227,7 @@ func TestAttachmentListWithPagination(t *testing.T) {
}
func TestAttachmentInvalidUID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)

View File

@ -4,16 +4,19 @@ import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/docker/docker/api/types/container"
"github.com/pkg/errors"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mysql"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/network"
"github.com/testcontainers/testcontainers-go/wait"
// Database drivers for connection verification.
@ -24,23 +27,51 @@ import (
const (
testUser = "root"
testPassword = "test"
// Memos container settings for migration testing.
MemosDockerImage = "neosmemo/memos"
StableMemosVersion = "stable" // Always points to the latest stable release
)
var (
mysqlContainer *mysql.MySQLContainer
postgresContainer *postgres.PostgresContainer
mysqlContainer atomic.Pointer[mysql.MySQLContainer]
postgresContainer atomic.Pointer[postgres.PostgresContainer]
mysqlOnce sync.Once
postgresOnce sync.Once
mysqlBaseDSN string
postgresBaseDSN string
mysqlBaseDSN atomic.Value // stores string
postgresBaseDSN atomic.Value // stores string
dbCounter atomic.Int64
dbCreationMutex sync.Mutex // Protects database creation operations
// Network for container communication.
testDockerNetwork atomic.Pointer[testcontainers.DockerNetwork]
testNetworkOnce sync.Once
)
// getTestNetwork creates or returns the shared Docker network for container communication.
func getTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error) {
var networkErr error
testNetworkOnce.Do(func() {
nw, err := network.New(ctx, network.WithDriver("bridge"))
if err != nil {
networkErr = err
return
}
testDockerNetwork.Store(nw)
})
return testDockerNetwork.Load(), networkErr
}
// GetMySQLDSN starts a MySQL container (if not already running) and creates a fresh database for this test.
func GetMySQLDSN(t *testing.T) string {
ctx := context.Background()
mysqlOnce.Do(func() {
nw, err := getTestNetwork(ctx)
if err != nil {
t.Fatalf("failed to create test network: %v", err)
}
container, err := mysql.Run(ctx,
"mysql:8",
mysql.WithDatabase("init_db"),
@ -55,11 +86,12 @@ func GetMySQLDSN(t *testing.T) string {
wait.ForListeningPort("3306/tcp"),
).WithDeadline(120*time.Second),
),
network.WithNetwork(nil, nw),
)
if err != nil {
t.Fatalf("failed to start MySQL container: %v", err)
}
mysqlContainer = container
mysqlContainer.Store(container)
dsn, err := container.ConnectionString(ctx, "multiStatements=true")
if err != nil {
@ -70,16 +102,21 @@ func GetMySQLDSN(t *testing.T) string {
t.Fatalf("MySQL not ready for connections: %v", err)
}
mysqlBaseDSN = dsn
mysqlBaseDSN.Store(dsn)
})
if mysqlBaseDSN == "" {
dsn, ok := mysqlBaseDSN.Load().(string)
if !ok || dsn == "" {
t.Fatal("MySQL container failed to start in a previous test")
}
// Serialize database creation to avoid "table already exists" race conditions
dbCreationMutex.Lock()
defer dbCreationMutex.Unlock()
// Create a fresh database for this test
dbName := fmt.Sprintf("memos_test_%d", dbCounter.Add(1))
db, err := sql.Open("mysql", mysqlBaseDSN)
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatalf("failed to connect to MySQL: %v", err)
}
@ -90,7 +127,7 @@ func GetMySQLDSN(t *testing.T) string {
}
// Return DSN pointing to the new database
return strings.Replace(mysqlBaseDSN, "/init_db?", "/"+dbName+"?", 1)
return strings.Replace(dsn, "/init_db?", "/"+dbName+"?", 1)
}
// waitForDB polls the database until it's ready or timeout is reached.
@ -130,6 +167,11 @@ func GetPostgresDSN(t *testing.T) string {
ctx := context.Background()
postgresOnce.Do(func() {
nw, err := getTestNetwork(ctx)
if err != nil {
t.Fatalf("failed to create test network: %v", err)
}
container, err := postgres.Run(ctx,
"postgres:18",
postgres.WithDatabase("init_db"),
@ -141,11 +183,12 @@ func GetPostgresDSN(t *testing.T) string {
wait.ForListeningPort("5432/tcp"),
).WithDeadline(120*time.Second),
),
network.WithNetwork(nil, nw),
)
if err != nil {
t.Fatalf("failed to start PostgreSQL container: %v", err)
}
postgresContainer = container
postgresContainer.Store(container)
dsn, err := container.ConnectionString(ctx, "sslmode=disable")
if err != nil {
@ -156,16 +199,21 @@ func GetPostgresDSN(t *testing.T) string {
t.Fatalf("PostgreSQL not ready for connections: %v", err)
}
postgresBaseDSN = dsn
postgresBaseDSN.Store(dsn)
})
if postgresBaseDSN == "" {
dsn, ok := postgresBaseDSN.Load().(string)
if !ok || dsn == "" {
t.Fatal("PostgreSQL container failed to start in a previous test")
}
// Serialize database creation to avoid "table already exists" race conditions
dbCreationMutex.Lock()
defer dbCreationMutex.Unlock()
// Create a fresh database for this test
dbName := fmt.Sprintf("memos_test_%d", dbCounter.Add(1))
db, err := sql.Open("postgres", postgresBaseDSN)
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatalf("failed to connect to PostgreSQL: %v", err)
}
@ -176,17 +224,94 @@ func GetPostgresDSN(t *testing.T) string {
}
// Return DSN pointing to the new database
return strings.Replace(postgresBaseDSN, "/init_db?", "/"+dbName+"?", 1)
return strings.Replace(dsn, "/init_db?", "/"+dbName+"?", 1)
}
// TerminateContainers cleans up all running containers.
// TerminateContainers cleans up all running containers and network.
// This is typically called from TestMain.
func TerminateContainers() {
ctx := context.Background()
if mysqlContainer != nil {
_ = mysqlContainer.Terminate(ctx)
if container := mysqlContainer.Load(); container != nil {
_ = container.Terminate(ctx)
}
if postgresContainer != nil {
_ = postgresContainer.Terminate(ctx)
if container := postgresContainer.Load(); container != nil {
_ = container.Terminate(ctx)
}
if network := testDockerNetwork.Load(); network != nil {
_ = network.Remove(ctx)
}
}
// MemosContainerConfig holds configuration for starting a Memos container.
type MemosContainerConfig struct {
Version string // Memos version tag (e.g., "0.24.0")
Driver string // Database driver: sqlite, mysql, postgres
DSN string // Database DSN (for mysql/postgres)
DataDir string // Host directory to mount for SQLite data
}
// MemosStartupWaitStrategy defines the wait strategy for Memos container startup.
// Uses regex to match various log message formats across versions.
var MemosStartupWaitStrategy = wait.ForAll(
wait.ForLog("(started successfully|has been started on port)").AsRegexp(),
wait.ForListeningPort("5230/tcp"),
).WithDeadline(180 * time.Second)
// StartMemosContainer starts a Memos container for migration testing.
// For SQLite, it mounts the dataDir to /var/opt/memos.
func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcontainers.Container, error) {
env := map[string]string{
"MEMOS_MODE": "prod",
}
var opts []testcontainers.ContainerCustomizer
switch cfg.Driver {
case "sqlite":
env["MEMOS_DRIVER"] = "sqlite"
opts = append(opts, testcontainers.WithHostConfigModifier(func(hc *container.HostConfig) {
hc.Binds = append(hc.Binds, fmt.Sprintf("%s:%s", cfg.DataDir, "/var/opt/memos"))
}))
default:
return nil, errors.Errorf("unsupported driver for migration testing: %s", cfg.Driver)
}
req := testcontainers.ContainerRequest{
Image: fmt.Sprintf("%s:%s", MemosDockerImage, cfg.Version),
Env: env,
ExposedPorts: []string{"5230/tcp"},
WaitingFor: MemosStartupWaitStrategy,
User: fmt.Sprintf("%d:%d", os.Getuid(), os.Getgid()),
}
// Use local image if specified
if cfg.Version == "local" {
if os.Getenv("MEMOS_TEST_IMAGE_BUILT") == "1" {
req.Image = "memos-test:local"
} else {
req.FromDockerfile = testcontainers.FromDockerfile{
Context: "../../",
Dockerfile: "Dockerfile",
}
}
}
genericReq := testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
}
// Apply options
for _, opt := range opts {
if err := opt.Customize(&genericReq); err != nil {
return nil, errors.Wrap(err, "failed to apply container option")
}
}
ctr, err := testcontainers.GenericContainer(ctx, genericReq)
if err != nil {
return nil, errors.Wrap(err, "failed to start memos container")
}
return ctr, nil
}

View File

@ -11,6 +11,7 @@ import (
)
func TestIdentityProviderStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
createdIDP, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
@ -58,3 +59,396 @@ func TestIdentityProviderStore(t *testing.T) {
require.Equal(t, 0, len(idpList))
ts.Close()
}
func TestIdentityProviderGetByID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create IDP
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
require.NoError(t, err)
// Get by ID
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.NotNil(t, found)
require.Equal(t, idp.Id, found.Id)
require.Equal(t, idp.Name, found.Name)
// Get by non-existent ID
nonExistentID := int32(99999)
notFound, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &nonExistentID})
require.NoError(t, err)
require.Nil(t, notFound)
ts.Close()
}
func TestIdentityProviderListMultiple(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create multiple IDPs
_, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth"))
require.NoError(t, err)
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth"))
require.NoError(t, err)
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitLab OAuth"))
require.NoError(t, err)
// List all
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
require.NoError(t, err)
require.Len(t, idpList, 3)
ts.Close()
}
func TestIdentityProviderListByID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create multiple IDPs
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth"))
require.NoError(t, err)
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth"))
require.NoError(t, err)
// List by specific ID
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{ID: &idp1.Id})
require.NoError(t, err)
require.Len(t, idpList, 1)
require.Equal(t, "GitHub OAuth", idpList[0].Name)
ts.Close()
}
func TestIdentityProviderUpdateName(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original Name"))
require.NoError(t, err)
require.Equal(t, "Original Name", idp.Name)
// Update name
newName := "Updated Name"
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
ID: idp.Id,
Type: storepb.IdentityProvider_OAUTH2,
Name: &newName,
})
require.NoError(t, err)
require.Equal(t, "Updated Name", updated.Name)
// Verify update persisted
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Equal(t, "Updated Name", found.Name)
ts.Close()
}
func TestIdentityProviderUpdateIdentifierFilter(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
require.NoError(t, err)
require.Equal(t, "", idp.IdentifierFilter)
// Update identifier filter
newFilter := "@example.com$"
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
ID: idp.Id,
Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: &newFilter,
})
require.NoError(t, err)
require.Equal(t, "@example.com$", updated.IdentifierFilter)
// Verify update persisted
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Equal(t, "@example.com$", found.IdentifierFilter)
ts.Close()
}
func TestIdentityProviderUpdateConfig(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
require.NoError(t, err)
// Update config
newConfig := &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: "new_client_id",
ClientSecret: "new_client_secret",
AuthUrl: "https://newprovider.com/auth",
TokenUrl: "https://newprovider.com/token",
UserInfoUrl: "https://newprovider.com/user",
Scopes: []string{"openid", "profile", "email"},
FieldMapping: &storepb.FieldMapping{
Identifier: "sub",
DisplayName: "name",
Email: "email",
},
},
},
}
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
ID: idp.Id,
Type: storepb.IdentityProvider_OAUTH2,
Config: newConfig,
})
require.NoError(t, err)
require.Equal(t, "new_client_id", updated.Config.GetOauth2Config().ClientId)
require.Equal(t, "new_client_secret", updated.Config.GetOauth2Config().ClientSecret)
require.Contains(t, updated.Config.GetOauth2Config().Scopes, "openid")
// Verify update persisted
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Equal(t, "new_client_id", found.Config.GetOauth2Config().ClientId)
ts.Close()
}
func TestIdentityProviderUpdateMultipleFields(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original"))
require.NoError(t, err)
// Update multiple fields at once
newName := "Updated IDP"
newFilter := "^admin@"
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
ID: idp.Id,
Type: storepb.IdentityProvider_OAUTH2,
Name: &newName,
IdentifierFilter: &newFilter,
})
require.NoError(t, err)
require.Equal(t, "Updated IDP", updated.Name)
require.Equal(t, "^admin@", updated.IdentifierFilter)
ts.Close()
}
func TestIdentityProviderDelete(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
require.NoError(t, err)
// Delete
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp.Id})
require.NoError(t, err)
// Verify deletion
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Nil(t, found)
ts.Close()
}
func TestIdentityProviderDeleteNotAffectOthers(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create multiple IDPs
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 1"))
require.NoError(t, err)
idp2, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 2"))
require.NoError(t, err)
// Delete first one
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp1.Id})
require.NoError(t, err)
// Verify second still exists
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp2.Id})
require.NoError(t, err)
require.NotNil(t, found)
require.Equal(t, "IDP 2", found.Name)
// Verify list only contains second
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
require.NoError(t, err)
require.Len(t, idpList, 1)
ts.Close()
}
func TestIdentityProviderOAuth2ConfigScopes(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create IDP with multiple scopes
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Name: "Multi-Scope OAuth",
Type: storepb.IdentityProvider_OAUTH2,
Config: &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: "client_id",
ClientSecret: "client_secret",
AuthUrl: "https://provider.com/auth",
TokenUrl: "https://provider.com/token",
UserInfoUrl: "https://provider.com/userinfo",
Scopes: []string{"openid", "profile", "email", "groups"},
FieldMapping: &storepb.FieldMapping{
Identifier: "sub",
DisplayName: "name",
Email: "email",
},
},
},
},
})
require.NoError(t, err)
// Verify scopes are preserved
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Len(t, found.Config.GetOauth2Config().Scopes, 4)
require.Contains(t, found.Config.GetOauth2Config().Scopes, "openid")
require.Contains(t, found.Config.GetOauth2Config().Scopes, "groups")
ts.Close()
}
func TestIdentityProviderFieldMapping(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Create IDP with custom field mapping
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Name: "Custom Field Mapping",
Type: storepb.IdentityProvider_OAUTH2,
Config: &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: "client_id",
ClientSecret: "client_secret",
AuthUrl: "https://provider.com/auth",
TokenUrl: "https://provider.com/token",
UserInfoUrl: "https://provider.com/userinfo",
Scopes: []string{"login"},
FieldMapping: &storepb.FieldMapping{
Identifier: "preferred_username",
DisplayName: "full_name",
Email: "email_address",
},
},
},
},
})
require.NoError(t, err)
// Verify field mapping is preserved
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Equal(t, "preferred_username", found.Config.GetOauth2Config().FieldMapping.Identifier)
require.Equal(t, "full_name", found.Config.GetOauth2Config().FieldMapping.DisplayName)
require.Equal(t, "email_address", found.Config.GetOauth2Config().FieldMapping.Email)
ts.Close()
}
func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
testCases := []struct {
name string
filter string
}{
{"Domain filter", "@company\\.com$"},
{"Prefix filter", "^admin_"},
{"Complex regex", "^[a-z]+@(dept1|dept2)\\.example\\.com$"},
{"Empty filter", ""},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Name: tc.name,
Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: tc.filter,
Config: &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: "client_id",
ClientSecret: "client_secret",
AuthUrl: "https://provider.com/auth",
TokenUrl: "https://provider.com/token",
UserInfoUrl: "https://provider.com/userinfo",
Scopes: []string{"login"},
FieldMapping: &storepb.FieldMapping{
Identifier: "sub",
},
},
},
},
})
require.NoError(t, err)
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
require.NoError(t, err)
require.Equal(t, tc.filter, found.IdentifierFilter)
// Cleanup
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp.Id})
require.NoError(t, err)
})
}
ts.Close()
}
// Helper function to create a test OAuth2 IDP.
func createTestOAuth2IDP(name string) *storepb.IdentityProvider {
return &storepb.IdentityProvider{
Name: name,
Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: "",
Config: &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: "client_id",
ClientSecret: "client_secret",
AuthUrl: "https://provider.com/auth",
TokenUrl: "https://provider.com/token",
UserInfoUrl: "https://provider.com/userinfo",
Scopes: []string{"login"},
FieldMapping: &storepb.FieldMapping{
Identifier: "login",
DisplayName: "name",
Email: "email",
},
},
},
},
}
}

View File

@ -11,6 +11,7 @@ import (
)
func TestInboxStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -52,3 +53,540 @@ func TestInboxStore(t *testing.T) {
require.Equal(t, 0, len(inboxes))
ts.Close()
}
func TestInboxListByID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
inbox, err := ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// List by ID
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ID: &inbox.ID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, inbox.ID, inboxes[0].ID)
// List by non-existent ID
nonExistentID := int32(99999)
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ID: &nonExistentID})
require.NoError(t, err)
require.Len(t, inboxes, 0)
ts.Close()
}
func TestInboxListBySenderID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user1, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
require.NoError(t, err)
// Create inbox from system bot (senderID = 0)
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user1.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// Create inbox from user2
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: user2.ID,
ReceiverID: user1.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// List by sender ID = user2
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{SenderID: &user2.ID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, user2.ID, inboxes[0].SenderID)
// List by sender ID = 0 (system bot)
systemBotID := int32(0)
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{SenderID: &systemBotID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, int32(0), inboxes[0].SenderID)
ts.Close()
}
func TestInboxListByStatus(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create UNREAD inbox
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// Create another inbox and archive it
inbox2, err := ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
_, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: inbox2.ID, Status: store.ARCHIVED})
require.NoError(t, err)
// List by UNREAD status
unreadStatus := store.UNREAD
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{Status: &unreadStatus})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, store.UNREAD, inboxes[0].Status)
// List by ARCHIVED status
archivedStatus := store.ARCHIVED
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{Status: &archivedStatus})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, store.ARCHIVED, inboxes[0].Status)
ts.Close()
}
func TestInboxListByMessageType(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create MEMO_COMMENT inboxes
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// List by MEMO_COMMENT type
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{MessageType: &memoCommentType})
require.NoError(t, err)
require.Len(t, inboxes, 2)
for _, inbox := range inboxes {
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
}
ts.Close()
}
func TestInboxListPagination(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create 5 inboxes
for i := 0; i < 5; i++ {
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
}
// Test Limit only
limit := 3
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
Limit: &limit,
})
require.NoError(t, err)
require.Len(t, inboxes, 3)
// Test Limit + Offset (offset requires limit in the implementation)
limit = 2
offset := 2
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
Limit: &limit,
Offset: &offset,
})
require.NoError(t, err)
require.Len(t, inboxes, 2)
// Test Limit + Offset skipping to end
limit = 10
offset = 3
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
Limit: &limit,
Offset: &offset,
})
require.NoError(t, err)
require.Len(t, inboxes, 2) // Only 2 remaining after offset of 3
ts.Close()
}
func TestInboxListCombinedFilters(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user1, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
require.NoError(t, err)
// Create various inboxes
// user2 -> user1, MEMO_COMMENT, UNREAD
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: user2.ID,
ReceiverID: user1.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// user2 -> user1, TYPE_UNSPECIFIED, UNREAD
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: user2.ID,
ReceiverID: user1.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_TYPE_UNSPECIFIED},
})
require.NoError(t, err)
// system -> user1, MEMO_COMMENT, ARCHIVED
inbox3, err := ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user1.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
_, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: inbox3.ID, Status: store.ARCHIVED})
require.NoError(t, err)
// Combined filter: ReceiverID + SenderID + Status
unreadStatus := store.UNREAD
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user1.ID,
SenderID: &user2.ID,
Status: &unreadStatus,
})
require.NoError(t, err)
require.Len(t, inboxes, 2)
// Combined filter: ReceiverID + MessageType + Status
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user1.ID,
MessageType: &memoCommentType,
Status: &unreadStatus,
})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, user2.ID, inboxes[0].SenderID)
ts.Close()
}
func TestInboxMessagePayload(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create inbox with message payload containing activity ID
activityID := int32(123)
inbox, err := ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
ActivityId: &activityID,
},
})
require.NoError(t, err)
require.NotNil(t, inbox.Message)
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
require.Equal(t, activityID, *inbox.Message.ActivityId)
// List and verify payload is preserved
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user.ID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, activityID, *inboxes[0].Message.ActivityId)
ts.Close()
}
func TestInboxUpdateStatus(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
inbox, err := ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
require.Equal(t, store.UNREAD, inbox.Status)
// Update to ARCHIVED
updated, err := ts.UpdateInbox(ctx, &store.UpdateInbox{
ID: inbox.ID,
Status: store.ARCHIVED,
})
require.NoError(t, err)
require.Equal(t, store.ARCHIVED, updated.Status)
require.Equal(t, inbox.ID, updated.ID)
// Verify the update persisted
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ID: &inbox.ID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, store.ARCHIVED, inboxes[0].Status)
ts.Close()
}
func TestInboxListByMessageTypeMultipleTypes(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create inboxes with different message types
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_TYPE_UNSPECIFIED},
})
require.NoError(t, err)
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// Filter by MEMO_COMMENT - should get 2
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
MessageType: &memoCommentType,
})
require.NoError(t, err)
require.Len(t, inboxes, 2)
for _, inbox := range inboxes {
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
}
// Filter by TYPE_UNSPECIFIED - should get 1
unspecifiedType := storepb.InboxMessage_TYPE_UNSPECIFIED
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
MessageType: &unspecifiedType,
})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, storepb.InboxMessage_TYPE_UNSPECIFIED, inboxes[0].Message.Type)
ts.Close()
}
func TestInboxMessageTypeFilterWithPayload(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create inbox with full payload
activityID := int32(456)
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_MEMO_COMMENT,
ActivityId: &activityID,
},
})
require.NoError(t, err)
// Create inbox with different type but also has payload
otherActivityID := int32(789)
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_TYPE_UNSPECIFIED,
ActivityId: &otherActivityID,
},
})
require.NoError(t, err)
// Filter by type should work correctly even with complex JSON payload
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
MessageType: &memoCommentType,
})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, activityID, *inboxes[0].Message.ActivityId)
ts.Close()
}
func TestInboxMessageTypeFilterWithStatusAndPagination(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create multiple inboxes with various combinations
for i := 0; i < 5; i++ {
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
}
// Archive 2 of them
allInboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user.ID})
require.NoError(t, err)
for i := 0; i < 2; i++ {
_, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: allInboxes[i].ID, Status: store.ARCHIVED})
require.NoError(t, err)
}
// Filter by type + status + pagination
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
unreadStatus := store.UNREAD
limit := 2
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
MessageType: &memoCommentType,
Status: &unreadStatus,
Limit: &limit,
})
require.NoError(t, err)
require.Len(t, inboxes, 2)
for _, inbox := range inboxes {
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
require.Equal(t, store.UNREAD, inbox.Status)
}
// Get next page
offset := 2
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
MessageType: &memoCommentType,
Status: &unreadStatus,
Limit: &limit,
Offset: &offset,
})
require.NoError(t, err)
require.Len(t, inboxes, 1) // Only 1 remaining (3 unread total, got 2, now 1 left)
ts.Close()
}
func TestInboxMultipleReceivers(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user1, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
require.NoError(t, err)
// Create inbox for user1
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user1.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// Create inbox for user2
_, err = ts.CreateInbox(ctx, &store.Inbox{
SenderID: 0,
ReceiverID: user2.ID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
})
require.NoError(t, err)
// User1 should only see their inbox
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user1.ID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, user1.ID, inboxes[0].ReceiverID)
// User2 should only see their inbox
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user2.ID})
require.NoError(t, err)
require.Len(t, inboxes, 1)
require.Equal(t, user2.ID, inboxes[0].ReceiverID)
ts.Close()
}

View File

@ -11,6 +11,7 @@ import (
)
func TestInstanceSettingV1Store(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
instanceSetting, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
@ -31,6 +32,7 @@ func TestInstanceSettingV1Store(t *testing.T) {
}
func TestInstanceSettingGetNonExistent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -45,6 +47,7 @@ func TestInstanceSettingGetNonExistent(t *testing.T) {
}
func TestInstanceSettingUpsertUpdate(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -88,6 +91,7 @@ func TestInstanceSettingUpsertUpdate(t *testing.T) {
}
func TestInstanceSettingBasicSetting(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -116,6 +120,7 @@ func TestInstanceSettingBasicSetting(t *testing.T) {
}
func TestInstanceSettingGeneralSetting(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -146,6 +151,7 @@ func TestInstanceSettingGeneralSetting(t *testing.T) {
}
func TestInstanceSettingMemoRelatedSetting(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -179,6 +185,7 @@ func TestInstanceSettingMemoRelatedSetting(t *testing.T) {
}
func TestInstanceSettingStorageSetting(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -214,6 +221,7 @@ func TestInstanceSettingStorageSetting(t *testing.T) {
}
func TestInstanceSettingListAll(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -246,3 +254,47 @@ func TestInstanceSettingListAll(t *testing.T) {
ts.Close()
}
func TestInstanceSettingEdgeCases(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Case 1: General Setting with special characters and Unicode
specialScript := `<script>alert("你好"); var x = 'test\'s';</script>`
specialStyle := `body { font-family: "Noto Sans SC", sans-serif; content: "\u2764"; }`
_, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
AdditionalScript: specialScript,
AdditionalStyle: specialStyle,
},
},
})
require.NoError(t, err)
generalSetting, err := ts.GetInstanceGeneralSetting(ctx)
require.NoError(t, err)
require.Equal(t, specialScript, generalSetting.AdditionalScript)
require.Equal(t, specialStyle, generalSetting.AdditionalStyle)
// Case 2: Memo Related Setting with Unicode reactions
unicodeReactions := []string{"🐱", "🐶", "🦊", "🦄"}
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_MEMO_RELATED,
Value: &storepb.InstanceSetting_MemoRelatedSetting{
MemoRelatedSetting: &storepb.InstanceMemoRelatedSetting{
ContentLengthLimit: 1000,
Reactions: unicodeReactions,
},
},
})
require.NoError(t, err)
memoSetting, err := ts.GetInstanceMemoRelatedSetting(ctx)
require.NoError(t, err)
require.Equal(t, unicodeReactions, memoSetting.Reactions)
ts.Close()
}

View File

@ -13,7 +13,7 @@ func TestMain(m *testing.M) {
// If DRIVER is set, run tests for that driver only
if os.Getenv("DRIVER") != "" {
defer TerminateContainers()
m.Run()
m.Run() //nolint:revive // Exit code is handled by test runner
return
}

View File

@ -16,6 +16,7 @@ import (
// =============================================================================
func TestMemoFilterContentContains(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -39,6 +40,7 @@ func TestMemoFilterContentContains(t *testing.T) {
}
func TestMemoFilterContentSpecialCharacters(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -49,6 +51,7 @@ func TestMemoFilterContentSpecialCharacters(t *testing.T) {
}
func TestMemoFilterContentUnicode(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -58,12 +61,35 @@ func TestMemoFilterContentUnicode(t *testing.T) {
require.Len(t, memos, 1)
}
func TestMemoFilterContentCaseSensitivity(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
tc.CreateMemo(NewMemoBuilder("memo-case", tc.User.ID).Content("MixedCase Content"))
// Exact match
memos := tc.ListWithFilter(`content.contains("MixedCase")`)
require.Len(t, memos, 1)
// Lowercase match (depends on DB collation, usually case-insensitive in default installs but good to verify behavior)
// SQLite default LIKE is case-insensitive for ASCII.
memosLower := tc.ListWithFilter(`content.contains("mixedcase")`)
// We just verify it doesn't crash; strict case sensitivity expectation depends on DB config.
// For standard Memos setup (SQLite), it's often case-insensitive.
// Let's check if we get a result or not to characterize current behavior.
if len(memosLower) > 0 {
require.Equal(t, "MixedCase Content", memosLower[0].Content)
}
}
// =============================================================================
// Visibility Field Tests
// Schema: visibility (string, ==, !=)
// =============================================================================
func TestMemoFilterVisibilityEquals(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -83,6 +109,7 @@ func TestMemoFilterVisibilityEquals(t *testing.T) {
}
func TestMemoFilterVisibilityNotEquals(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -95,6 +122,7 @@ func TestMemoFilterVisibilityNotEquals(t *testing.T) {
}
func TestMemoFilterVisibilityInList(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -112,6 +140,7 @@ func TestMemoFilterVisibilityInList(t *testing.T) {
// =============================================================================
func TestMemoFilterPinnedEquals(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -131,6 +160,7 @@ func TestMemoFilterPinnedEquals(t *testing.T) {
}
func TestMemoFilterPinnedPredicate(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -149,6 +179,7 @@ func TestMemoFilterPinnedPredicate(t *testing.T) {
// =============================================================================
func TestMemoFilterCreatorIdEquals(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -169,6 +200,7 @@ func TestMemoFilterCreatorIdEquals(t *testing.T) {
}
func TestMemoFilterCreatorIdNotEquals(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -195,6 +227,7 @@ func TestMemoFilterCreatorIdNotEquals(t *testing.T) {
// =============================================================================
func TestMemoFilterTagInList(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -213,6 +246,7 @@ func TestMemoFilterTagInList(t *testing.T) {
}
func TestMemoFilterElementInTags(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -229,6 +263,7 @@ func TestMemoFilterElementInTags(t *testing.T) {
}
func TestMemoFilterHierarchicalTags(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -241,6 +276,7 @@ func TestMemoFilterHierarchicalTags(t *testing.T) {
}
func TestMemoFilterEmptyTags(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -257,6 +293,7 @@ func TestMemoFilterEmptyTags(t *testing.T) {
// =============================================================================
func TestMemoFilterHasTaskList(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -279,6 +316,7 @@ func TestMemoFilterHasTaskList(t *testing.T) {
}
func TestMemoFilterHasLink(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -293,6 +331,7 @@ func TestMemoFilterHasLink(t *testing.T) {
}
func TestMemoFilterHasCode(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -307,6 +346,7 @@ func TestMemoFilterHasCode(t *testing.T) {
}
func TestMemoFilterHasIncompleteTasks(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -329,6 +369,7 @@ func TestMemoFilterHasIncompleteTasks(t *testing.T) {
}
func TestMemoFilterCombinedJSONBool(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -367,6 +408,7 @@ func TestMemoFilterCombinedJSONBool(t *testing.T) {
// =============================================================================
func TestMemoFilterCreatedTsComparison(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -387,6 +429,7 @@ func TestMemoFilterCreatedTsComparison(t *testing.T) {
}
func TestMemoFilterCreatedTsWithNow(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -402,6 +445,7 @@ func TestMemoFilterCreatedTsWithNow(t *testing.T) {
}
func TestMemoFilterCreatedTsArithmetic(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -421,6 +465,7 @@ func TestMemoFilterCreatedTsArithmetic(t *testing.T) {
}
func TestMemoFilterUpdatedTs(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -444,6 +489,7 @@ func TestMemoFilterUpdatedTs(t *testing.T) {
}
func TestMemoFilterAllComparisonOperators(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -472,6 +518,7 @@ func TestMemoFilterAllComparisonOperators(t *testing.T) {
// =============================================================================
func TestMemoFilterLogicalAnd(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -485,6 +532,7 @@ func TestMemoFilterLogicalAnd(t *testing.T) {
}
func TestMemoFilterLogicalOr(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -497,6 +545,7 @@ func TestMemoFilterLogicalOr(t *testing.T) {
}
func TestMemoFilterLogicalNot(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -510,6 +559,7 @@ func TestMemoFilterLogicalNot(t *testing.T) {
}
func TestMemoFilterNegatedComparison(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -522,6 +572,7 @@ func TestMemoFilterNegatedComparison(t *testing.T) {
}
func TestMemoFilterComplexLogical(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -545,6 +596,244 @@ func TestMemoFilterComplexLogical(t *testing.T) {
// Test: (pinned || tag in ["important"]) && visibility == "PUBLIC"
memos = tc.ListWithFilter(`(pinned || tag in ["important"]) && visibility == "PUBLIC"`)
require.Len(t, memos, 3)
// Test: De Morgan's Law ! (A || B) == !A && !B
// ! (pinned || has_task_list)
tc.CreateMemo(NewMemoBuilder("memo-no-props", tc.User.ID).Content("Nothing special"))
memos = tc.ListWithFilter(`!(pinned || has_task_list)`)
require.Len(t, memos, 2) // Unpinned-tagged + Nothing special (pinned-untagged is pinned)
}
// =============================================================================
// Tag Comprehension Tests (exists macro)
// Schema: tags (list of strings, supports exists/all macros with predicates)
// =============================================================================
func TestMemoFilterTagsExistsStartsWith(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create memos with different tags
tc.CreateMemo(NewMemoBuilder("memo-archive1", tc.User.ID).
Content("Archived project memo").
Tags("archive/project", "done"))
tc.CreateMemo(NewMemoBuilder("memo-archive2", tc.User.ID).
Content("Archived work memo").
Tags("archive/work", "old"))
tc.CreateMemo(NewMemoBuilder("memo-active", tc.User.ID).
Content("Active project memo").
Tags("project/active", "todo"))
tc.CreateMemo(NewMemoBuilder("memo-homelab", tc.User.ID).
Content("Homelab memo").
Tags("homelab/memos", "tech"))
// Test: tags.exists(t, t.startsWith("archive")) - should match archived memos
memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("archive"))`)
require.Len(t, memos, 2, "Should find 2 archived memos")
for _, memo := range memos {
hasArchiveTag := false
for _, tag := range memo.Payload.Tags {
if len(tag) >= 7 && tag[:7] == "archive" {
hasArchiveTag = true
break
}
}
require.True(t, hasArchiveTag, "Memo should have tag starting with 'archive'")
}
// Test: !tags.exists(t, t.startsWith("archive")) - should match non-archived memos
memos = tc.ListWithFilter(`!tags.exists(t, t.startsWith("archive"))`)
require.Len(t, memos, 2, "Should find 2 non-archived memos")
// Test: tags.exists(t, t.startsWith("project")) - should match project memos
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("project"))`)
require.Len(t, memos, 1, "Should find 1 project memo")
// Test: tags.exists(t, t.startsWith("homelab")) - should match homelab memos
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("homelab"))`)
require.Len(t, memos, 1, "Should find 1 homelab memo")
// Test: tags.exists(t, t.startsWith("nonexistent")) - should match nothing
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("nonexistent"))`)
require.Len(t, memos, 0, "Should find no memos")
}
func TestMemoFilterTagsExistsContains(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create memos with different tags
tc.CreateMemo(NewMemoBuilder("memo-todo1", tc.User.ID).
Content("Todo task 1").
Tags("project/todo", "urgent"))
tc.CreateMemo(NewMemoBuilder("memo-todo2", tc.User.ID).
Content("Todo task 2").
Tags("work/todo-list", "pending"))
tc.CreateMemo(NewMemoBuilder("memo-done", tc.User.ID).
Content("Done task").
Tags("project/completed", "done"))
// Test: tags.exists(t, t.contains("todo")) - should match todos
memos := tc.ListWithFilter(`tags.exists(t, t.contains("todo"))`)
require.Len(t, memos, 2, "Should find 2 todo memos")
// Test: tags.exists(t, t.contains("done")) - should match done
memos = tc.ListWithFilter(`tags.exists(t, t.contains("done"))`)
require.Len(t, memos, 1, "Should find 1 done memo")
// Test: !tags.exists(t, t.contains("todo")) - should exclude todos
memos = tc.ListWithFilter(`!tags.exists(t, t.contains("todo"))`)
require.Len(t, memos, 1, "Should find 1 non-todo memo")
}
func TestMemoFilterTagsExistsEndsWith(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create memos with different tag endings
tc.CreateMemo(NewMemoBuilder("memo-bug", tc.User.ID).
Content("Bug report").
Tags("project/bug", "critical"))
tc.CreateMemo(NewMemoBuilder("memo-debug", tc.User.ID).
Content("Debug session").
Tags("work/debug", "dev"))
tc.CreateMemo(NewMemoBuilder("memo-feature", tc.User.ID).
Content("New feature").
Tags("project/feature", "new"))
// Test: tags.exists(t, t.endsWith("bug")) - should match bug-related tags
memos := tc.ListWithFilter(`tags.exists(t, t.endsWith("bug"))`)
require.Len(t, memos, 2, "Should find 2 bug-related memos")
// Test: tags.exists(t, t.endsWith("feature")) - should match feature
memos = tc.ListWithFilter(`tags.exists(t, t.endsWith("feature"))`)
require.Len(t, memos, 1, "Should find 1 feature memo")
// Test: !tags.exists(t, t.endsWith("bug")) - should exclude bug-related
memos = tc.ListWithFilter(`!tags.exists(t, t.endsWith("bug"))`)
require.Len(t, memos, 1, "Should find 1 non-bug memo")
}
func TestMemoFilterTagsExistsCombinedWithOtherFilters(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create memos with tags and other properties
tc.CreateMemo(NewMemoBuilder("memo-archived-old", tc.User.ID).
Content("Old archived memo").
Tags("archive/old", "done"))
tc.CreateMemo(NewMemoBuilder("memo-archived-recent", tc.User.ID).
Content("Recent archived memo with TODO").
Tags("archive/recent", "done"))
tc.CreateMemo(NewMemoBuilder("memo-active-todo", tc.User.ID).
Content("Active TODO").
Tags("project/active", "todo"))
// Test: Combine tag filter with content filter
memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) && content.contains("TODO")`)
require.Len(t, memos, 1, "Should find 1 archived memo with TODO in content")
// Test: OR condition with tag filters
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) || tags.exists(t, t.contains("todo"))`)
require.Len(t, memos, 3, "Should find all memos (archived or with todo tag)")
// Test: Complex filter - archived but not containing "Recent"
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) && !content.contains("Recent")`)
require.Len(t, memos, 1, "Should find 1 old archived memo")
}
func TestMemoFilterTagsExistsEmptyAndNullCases(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create memo with no tags
tc.CreateMemo(NewMemoBuilder("memo-no-tags", tc.User.ID).
Content("Memo without tags"))
// Create memo with tags
tc.CreateMemo(NewMemoBuilder("memo-with-tags", tc.User.ID).
Content("Memo with tags").
Tags("tag1", "tag2"))
// Test: tags.exists should not match memos without tags
memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("tag"))`)
require.Len(t, memos, 1, "Should only find memo with tags")
// Test: Negation should match memos without matching tags
memos = tc.ListWithFilter(`!tags.exists(t, t.startsWith("tag"))`)
require.Len(t, memos, 1, "Should find memo without matching tags")
}
func TestMemoFilterIssue5480_ArchiveWorkflow(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// Create a realistic scenario as described in issue #5480
// User has hierarchical tags and archives memos by prefixing with "archive"
// Active memos
tc.CreateMemo(NewMemoBuilder("memo-homelab", tc.User.ID).
Content("Setting up Memos").
Tags("homelab/memos", "tech"))
tc.CreateMemo(NewMemoBuilder("memo-project-alpha", tc.User.ID).
Content("Project Alpha notes").
Tags("work/project-alpha", "active"))
// Archived memos (user prefixed tags with "archive")
tc.CreateMemo(NewMemoBuilder("memo-old-homelab", tc.User.ID).
Content("Old homelab setup").
Tags("archive/homelab/old-server", "done"))
tc.CreateMemo(NewMemoBuilder("memo-old-project", tc.User.ID).
Content("Old project beta").
Tags("archive/work/project-beta", "completed"))
tc.CreateMemo(NewMemoBuilder("memo-archived-personal", tc.User.ID).
Content("Archived personal note").
Tags("archive/personal/2024", "old"))
// Test: Filter out ALL archived memos using startsWith
memos := tc.ListWithFilter(`!tags.exists(t, t.startsWith("archive"))`)
require.Len(t, memos, 2, "Should only show active memos (not archived)")
for _, memo := range memos {
for _, tag := range memo.Payload.Tags {
require.NotContains(t, tag, "archive", "Active memos should not have archive prefix")
}
}
// Test: Show ONLY archived memos
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive"))`)
require.Len(t, memos, 3, "Should find all archived memos")
for _, memo := range memos {
hasArchiveTag := false
for _, tag := range memo.Payload.Tags {
if len(tag) >= 7 && tag[:7] == "archive" {
hasArchiveTag = true
break
}
}
require.True(t, hasArchiveTag, "All returned memos should have archive prefix")
}
// Test: Filter archived homelab memos specifically
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive/homelab"))`)
require.Len(t, memos, 1, "Should find only archived homelab memos")
}
// =============================================================================
@ -552,6 +841,7 @@ func TestMemoFilterComplexLogical(t *testing.T) {
// =============================================================================
func TestMemoFilterMultipleFilters(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -570,6 +860,7 @@ func TestMemoFilterMultipleFilters(t *testing.T) {
// =============================================================================
func TestMemoFilterNullPayload(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -581,6 +872,7 @@ func TestMemoFilterNullPayload(t *testing.T) {
}
func TestMemoFilterNoMatches(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
@ -589,3 +881,48 @@ func TestMemoFilterNoMatches(t *testing.T) {
memos := tc.ListWithFilter(`content.contains("nonexistent12345")`)
require.Len(t, memos, 0)
}
func TestMemoFilterJSONBooleanLogic(t *testing.T) {
t.Parallel()
tc := NewMemoFilterTestContext(t)
defer tc.Close()
// 1. Memo with task list (true) and NO link (null)
tc.CreateMemo(NewMemoBuilder("memo-task-only", tc.User.ID).
Content("Task only").
Property(func(p *storepb.MemoPayload_Property) { p.HasTaskList = true }))
// 2. Memo with link (true) and NO task list (null)
tc.CreateMemo(NewMemoBuilder("memo-link-only", tc.User.ID).
Content("Link only").
Property(func(p *storepb.MemoPayload_Property) { p.HasLink = true }))
// 3. Memo with both (true)
tc.CreateMemo(NewMemoBuilder("memo-both", tc.User.ID).
Content("Both").
Property(func(p *storepb.MemoPayload_Property) {
p.HasTaskList = true
p.HasLink = true
}))
// 4. Memo with neither (null)
tc.CreateMemo(NewMemoBuilder("memo-neither", tc.User.ID).Content("Neither"))
// Test A: has_task_list || has_link
// Expected: 3 memos (task-only, link-only, both). Neither should be excluded.
// This specifically tests the NULL handling in OR logic (NULL || TRUE should be TRUE)
memos := tc.ListWithFilter(`has_task_list || has_link`)
require.Len(t, memos, 3, "Should find 3 memos with OR logic")
// Test B: !has_task_list
// Expected: 2 memos (link-only, neither). Memos where has_task_list is NULL or FALSE.
// Note: If NULL is not handled, !NULL is still NULL (false-y in WHERE), so "neither" might be missed depending on logic.
// In our implementation, we want missing fields to behave as false.
memos = tc.ListWithFilter(`!has_task_list`)
require.Len(t, memos, 2, "Should find 2 memos where task list is false or missing")
// Test C: has_task_list && !has_link
// Expected: 1 memo (task-only).
memos = tc.ListWithFilter(`has_task_list && !has_link`)
require.Len(t, memos, 1, "Should find 1 memo (task only)")
}

View File

@ -10,6 +10,7 @@ import (
)
func TestMemoRelationStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -62,6 +63,7 @@ func TestMemoRelationStore(t *testing.T) {
}
func TestMemoRelationListByMemoID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -136,6 +138,7 @@ func TestMemoRelationListByMemoID(t *testing.T) {
}
func TestMemoRelationDelete(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -193,6 +196,7 @@ func TestMemoRelationDelete(t *testing.T) {
}
func TestMemoRelationDifferentTypes(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -239,3 +243,440 @@ func TestMemoRelationDifferentTypes(t *testing.T) {
ts.Close()
}
func TestMemoRelationUpsertSameRelation(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "main-memo",
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo",
CreatorID: user.ID,
Content: "related memo content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create relation
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Upsert the same relation again (should not create duplicate)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Verify only one relation exists
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
})
require.NoError(t, err)
require.Len(t, relations, 1)
ts.Close()
}
func TestMemoRelationDeleteByType(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "main-memo",
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo1, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo-1",
CreatorID: user.ID,
Content: "related memo 1 content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo2, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo-2",
CreatorID: user.ID,
Content: "related memo 2 content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create reference relations
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo1.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Create comment relation
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo2.ID,
Type: store.MemoRelationComment,
})
require.NoError(t, err)
// Delete only reference type relations
refType := store.MemoRelationReference
err = ts.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
MemoID: &mainMemo.ID,
Type: &refType,
})
require.NoError(t, err)
// Verify only comment relation remains
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
})
require.NoError(t, err)
require.Len(t, relations, 1)
require.Equal(t, store.MemoRelationComment, relations[0].Type)
ts.Close()
}
func TestMemoRelationDeleteByMemoID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memo1, err := ts.CreateMemo(ctx, &store.Memo{
UID: "memo-1",
CreatorID: user.ID,
Content: "memo 1 content",
Visibility: store.Public,
})
require.NoError(t, err)
memo2, err := ts.CreateMemo(ctx, &store.Memo{
UID: "memo-2",
CreatorID: user.ID,
Content: "memo 2 content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo",
CreatorID: user.ID,
Content: "related memo content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create relations for both memos
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memo1.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memo2.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Delete all relations for memo1
err = ts.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
MemoID: &memo1.ID,
})
require.NoError(t, err)
// Verify memo1's relations are gone
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memo1.ID,
})
require.NoError(t, err)
require.Len(t, relations, 0)
// Verify memo2's relations still exist
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memo2.ID,
})
require.NoError(t, err)
require.Len(t, relations, 1)
ts.Close()
}
func TestMemoRelationListByRelatedMemoID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create a memo that will be referenced by others
targetMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "target-memo",
CreatorID: user.ID,
Content: "target memo content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create memos that reference the target
referrer1, err := ts.CreateMemo(ctx, &store.Memo{
UID: "referrer-1",
CreatorID: user.ID,
Content: "referrer 1 content",
Visibility: store.Public,
})
require.NoError(t, err)
referrer2, err := ts.CreateMemo(ctx, &store.Memo{
UID: "referrer-2",
CreatorID: user.ID,
Content: "referrer 2 content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create relations pointing to target
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: referrer1.ID,
RelatedMemoID: targetMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: referrer2.ID,
RelatedMemoID: targetMemo.ID,
Type: store.MemoRelationComment,
})
require.NoError(t, err)
// List by related memo ID (find all memos that reference the target)
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &targetMemo.ID,
})
require.NoError(t, err)
require.Len(t, relations, 2)
ts.Close()
}
func TestMemoRelationListCombinedFilters(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "main-memo",
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo1, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo-1",
CreatorID: user.ID,
Content: "related memo 1 content",
Visibility: store.Public,
})
require.NoError(t, err)
relatedMemo2, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo-2",
CreatorID: user.ID,
Content: "related memo 2 content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create multiple relations
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo1.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo2.ID,
Type: store.MemoRelationComment,
})
require.NoError(t, err)
// List with MemoID and Type filter
refType := store.MemoRelationReference
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
Type: &refType,
})
require.NoError(t, err)
require.Len(t, relations, 1)
require.Equal(t, relatedMemo1.ID, relations[0].RelatedMemoID)
// List with MemoID, RelatedMemoID, and Type filter
commentType := store.MemoRelationComment
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
RelatedMemoID: &relatedMemo2.ID,
Type: &commentType,
})
require.NoError(t, err)
require.Len(t, relations, 1)
ts.Close()
}
func TestMemoRelationListEmpty(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "memo-no-relations",
CreatorID: user.ID,
Content: "memo with no relations",
Visibility: store.Public,
})
require.NoError(t, err)
// List relations for memo with none
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memo.ID,
})
require.NoError(t, err)
require.Len(t, relations, 0)
ts.Close()
}
func TestMemoRelationBidirectional(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memoA, err := ts.CreateMemo(ctx, &store.Memo{
UID: "memo-a",
CreatorID: user.ID,
Content: "memo A content",
Visibility: store.Public,
})
require.NoError(t, err)
memoB, err := ts.CreateMemo(ctx, &store.Memo{
UID: "memo-b",
CreatorID: user.ID,
Content: "memo B content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create relation A -> B
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memoA.ID,
RelatedMemoID: memoB.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Create relation B -> A (reverse direction)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memoB.ID,
RelatedMemoID: memoA.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
// Verify A -> B exists
relationsFromA, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memoA.ID,
})
require.NoError(t, err)
require.Len(t, relationsFromA, 1)
require.Equal(t, memoB.ID, relationsFromA[0].RelatedMemoID)
// Verify B -> A exists
relationsFromB, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memoB.ID,
})
require.NoError(t, err)
require.Len(t, relationsFromB, 1)
require.Equal(t, memoA.ID, relationsFromB[0].RelatedMemoID)
ts.Close()
}
func TestMemoRelationMultipleRelationsToSameMemo(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "main-memo",
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
})
require.NoError(t, err)
// Create multiple memos that all relate to the main memo
for i := 1; i <= 5; i++ {
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "related-memo-" + string(rune('0'+i)),
CreatorID: user.ID,
Content: "related memo content",
Visibility: store.Public,
})
require.NoError(t, err)
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: mainMemo.ID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationReference,
})
require.NoError(t, err)
}
// Verify all 5 relations exist
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &mainMemo.ID,
})
require.NoError(t, err)
require.Len(t, relations, 5)
ts.Close()
}

View File

@ -13,6 +13,7 @@ import (
)
func TestMemoStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -64,6 +65,7 @@ func TestMemoStore(t *testing.T) {
}
func TestMemoListByTags(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -96,6 +98,7 @@ func TestMemoListByTags(t *testing.T) {
}
func TestDeleteMemoStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -117,6 +120,7 @@ func TestDeleteMemoStore(t *testing.T) {
}
func TestMemoGetByID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -147,6 +151,7 @@ func TestMemoGetByID(t *testing.T) {
}
func TestMemoGetByUID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -177,6 +182,7 @@ func TestMemoGetByUID(t *testing.T) {
}
func TestMemoListByVisibility(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -239,6 +245,7 @@ func TestMemoListByVisibility(t *testing.T) {
}
func TestMemoListWithPagination(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -274,6 +281,7 @@ func TestMemoListWithPagination(t *testing.T) {
}
func TestMemoUpdatePinned(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -317,6 +325,7 @@ func TestMemoUpdatePinned(t *testing.T) {
}
func TestMemoUpdateVisibility(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -359,6 +368,7 @@ func TestMemoUpdateVisibility(t *testing.T) {
}
func TestMemoInvalidUID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -378,6 +388,7 @@ func TestMemoInvalidUID(t *testing.T) {
}
func TestMemoWithPayload(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)

View File

@ -2,16 +2,199 @@ package test
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/store"
)
func TestGetCurrentSchemaVersion(t *testing.T) {
// TestFreshInstall verifies that LATEST.sql applies correctly on a fresh database.
// This is essentially what NewTestingStore already does, but we make it explicit.
func TestFreshInstall(t *testing.T) {
t.Parallel()
ctx := context.Background()
// NewTestingStore creates a fresh database and runs Migrate()
// which applies LATEST.sql for uninitialized databases
ts := NewTestingStore(ctx, t)
// Verify migration completed successfully
currentSchemaVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
require.NotEmpty(t, currentSchemaVersion, "schema version should be set after fresh install")
// Verify we can read instance settings (basic sanity check)
instanceSetting, err := ts.GetInstanceBasicSetting(ctx)
require.NoError(t, err)
require.Equal(t, currentSchemaVersion, instanceSetting.SchemaVersion)
}
// TestMigrationReRun verifies that re-running the migration on an already
// migrated database does not fail or cause issues. This simulates a
// scenario where the server is restarted.
func TestMigrationReRun(t *testing.T) {
t.Parallel()
ctx := context.Background()
// Use the shared testing store which already runs migrations on init
ts := NewTestingStore(ctx, t)
// Get current version
initialVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
// Manually trigger migration again
err = ts.Migrate(ctx)
require.NoError(t, err, "re-running migration should not fail")
// Verify version hasn't changed (or at least is valid)
finalVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
require.Equal(t, initialVersion, finalVersion, "version should match after re-run")
}
// TestMigrationWithData verifies that migration preserves data integrity.
// Creates data, then re-runs migration and verifies data is still accessible.
func TestMigrationWithData(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
currentSchemaVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
require.Equal(t, "0.25.1", currentSchemaVersion)
// Create a user and memo before re-running migration
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err, "should create user")
originalMemo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "migration-data-test",
CreatorID: user.ID,
Content: "Data before migration re-run",
Visibility: store.Public,
})
require.NoError(t, err, "should create memo")
// Re-run migration
err = ts.Migrate(ctx)
require.NoError(t, err, "re-running migration should not fail")
// Verify data is still accessible
memo, err := ts.GetMemo(ctx, &store.FindMemo{UID: &originalMemo.UID})
require.NoError(t, err, "should retrieve memo after migration")
require.Equal(t, "Data before migration re-run", memo.Content, "memo content should be preserved")
}
// TestMigrationMultipleReRuns verifies that migration is idempotent
// even when run multiple times in succession.
func TestMigrationMultipleReRuns(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Get initial version
initialVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
// Run migration multiple times
for i := 0; i < 3; i++ {
err = ts.Migrate(ctx)
require.NoError(t, err, "migration run %d should not fail", i+1)
}
// Verify version is still correct
finalVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
require.Equal(t, initialVersion, finalVersion, "version should remain unchanged after multiple re-runs")
}
// TestMigrationFromStableVersion verifies that upgrading from a stable Memos version
// to the current version works correctly. This is the critical upgrade path test.
//
// Test flow:
// 1. Start a stable Memos container to create a database with the old schema
// 2. Stop the container and wait for cleanup
// 3. Use the store directly to run migration with current code
// 4. Verify the migration succeeded and data can be written
//
// Note: This test is skipped when running with -race flag because testcontainers
// has known race conditions in its reaper code that are outside our control.
func TestMigrationFromStableVersion(t *testing.T) {
// Skip for non-SQLite drivers (simplifies the test)
if getDriverFromEnv() != "sqlite" {
t.Skip("skipping upgrade test for non-sqlite driver")
}
// Skip if explicitly disabled (e.g., in environments without Docker)
if os.Getenv("SKIP_CONTAINER_TESTS") == "1" {
t.Skip("skipping container-based test (SKIP_CONTAINER_TESTS=1)")
}
ctx := context.Background()
dataDir := t.TempDir()
// 1. Start stable Memos container to create database with old schema
cfg := MemosContainerConfig{
Driver: "sqlite",
DataDir: dataDir,
Version: StableMemosVersion,
}
t.Logf("Starting Memos %s container to create old-schema database...", cfg.Version)
container, err := StartMemosContainer(ctx, cfg)
require.NoError(t, err, "failed to start stable memos container")
// Wait for the container to fully initialize the database
time.Sleep(10 * time.Second)
// Stop the container gracefully
t.Log("Stopping stable Memos container...")
err = container.Terminate(ctx)
require.NoError(t, err, "failed to stop memos container")
// Wait for file handles to be released
time.Sleep(2 * time.Second)
// 2. Connect to the database directly and run migration with current code
dsn := fmt.Sprintf("%s/memos_prod.db", dataDir)
t.Logf("Connecting to database at %s...", dsn)
ts := NewTestingStoreWithDSN(ctx, t, "sqlite", dsn)
// Get the schema version before migration
oldSetting, err := ts.GetInstanceBasicSetting(ctx)
require.NoError(t, err)
t.Logf("Old schema version: %s", oldSetting.SchemaVersion)
// 3. Run migration with current code
t.Log("Running migration with current code...")
err = ts.Migrate(ctx)
require.NoError(t, err, "migration from stable version should succeed")
// 4. Verify migration succeeded
newVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
t.Logf("New schema version: %s", newVersion)
newSetting, err := ts.GetInstanceBasicSetting(ctx)
require.NoError(t, err)
require.Equal(t, newVersion, newSetting.SchemaVersion, "schema version should be updated")
// Verify we can write data to the migrated database
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err, "should create user after migration")
memo, err := ts.CreateMemo(ctx, &store.Memo{
UID: "post-upgrade-memo",
CreatorID: user.ID,
Content: "Content after upgrade from stable",
Visibility: store.Public,
})
require.NoError(t, err, "should create memo after migration")
require.Equal(t, "Content after upgrade from stable", memo.Content)
t.Logf("Migration successful: %s -> %s", oldSetting.SchemaVersion, newVersion)
}

View File

@ -10,6 +10,7 @@ import (
)
func TestReactionStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -67,6 +68,7 @@ func TestReactionStore(t *testing.T) {
}
func TestReactionListByCreatorID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -113,6 +115,7 @@ func TestReactionListByCreatorID(t *testing.T) {
}
func TestReactionMultipleContentIDs(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
@ -148,6 +151,7 @@ func TestReactionMultipleContentIDs(t *testing.T) {
}
func TestReactionUpsertDifferentTypes(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)

View File

@ -37,6 +37,27 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
return store
}
// NewTestingStoreWithDSN creates a testing store connected to a specific DSN.
// This is useful for testing migrations on existing data.
func NewTestingStoreWithDSN(_ context.Context, t *testing.T, driver, dsn string) *store.Store {
profile := &profile.Profile{
Port: getUnusedPort(),
Data: t.TempDir(), // Dummy dir, DSN matters
DSN: dsn,
Driver: driver,
Version: version.GetCurrentVersion(),
}
dbDriver, err := db.NewDBDriver(profile)
if err != nil {
t.Fatalf("failed to create db driver: %v", err)
}
store := store.New(dbDriver, profile)
// Do not run Migrate() automatically, as we might be testing pre-migration state
// or want to run it manually.
return store
}
func getUnusedPort() int {
// Get a random unused port
listener, err := net.Listen("tcp", "localhost:0")
@ -73,12 +94,11 @@ func getTestingProfileForDriver(t *testing.T, driver string) *profile.Profile {
}
return &profile.Profile{
Mode: mode,
Port: port,
Data: dir,
DSN: dsn,
Driver: driver,
Version: version.GetCurrentVersion(mode),
Version: version.GetCurrentVersion(),
}
}

View File

@ -2,15 +2,18 @@ package test
import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func TestUserSettingStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -28,6 +31,7 @@ func TestUserSettingStore(t *testing.T) {
}
func TestUserSettingGetByUserID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -62,6 +66,7 @@ func TestUserSettingGetByUserID(t *testing.T) {
}
func TestUserSettingUpsertUpdate(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -100,6 +105,7 @@ func TestUserSettingUpsertUpdate(t *testing.T) {
}
func TestUserSettingRefreshTokens(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -160,6 +166,7 @@ func TestUserSettingRefreshTokens(t *testing.T) {
}
func TestUserSettingPersonalAccessTokens(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -211,6 +218,7 @@ func TestUserSettingPersonalAccessTokens(t *testing.T) {
}
func TestUserSettingWebhooks(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -277,6 +285,7 @@ func TestUserSettingWebhooks(t *testing.T) {
}
func TestUserSettingShortcuts(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
@ -308,3 +317,559 @@ func TestUserSettingShortcuts(t *testing.T) {
ts.Close()
}
func TestUserSettingGetUserByPATHash(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create a PAT with a known hash
patHash := "test-pat-hash-12345"
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-test-1",
TokenHash: patHash,
Description: "Test PAT for lookup",
}
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat)
require.NoError(t, err)
// Lookup user by PAT hash
result, err := ts.GetUserByPATHash(ctx, patHash)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, user.ID, result.UserID)
require.NotNil(t, result.User)
require.Equal(t, user.Username, result.User.Username)
require.NotNil(t, result.PAT)
require.Equal(t, "pat-test-1", result.PAT.TokenId)
require.Equal(t, "Test PAT for lookup", result.PAT.Description)
ts.Close()
}
func TestUserSettingGetUserByPATHashNotFound(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
_, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Lookup non-existent PAT hash
result, err := ts.GetUserByPATHash(ctx, "non-existent-hash")
require.Error(t, err)
require.Nil(t, result)
ts.Close()
}
func TestUserSettingGetUserByPATHashMultipleUsers(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user1, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
require.NoError(t, err)
// Create PATs for both users
pat1Hash := "user1-pat-hash"
err = ts.AddUserPersonalAccessToken(ctx, user1.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-user1",
TokenHash: pat1Hash,
Description: "User 1 PAT",
})
require.NoError(t, err)
pat2Hash := "user2-pat-hash"
err = ts.AddUserPersonalAccessToken(ctx, user2.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-user2",
TokenHash: pat2Hash,
Description: "User 2 PAT",
})
require.NoError(t, err)
// Lookup user1's PAT
result1, err := ts.GetUserByPATHash(ctx, pat1Hash)
require.NoError(t, err)
require.Equal(t, user1.ID, result1.UserID)
require.Equal(t, user1.Username, result1.User.Username)
// Lookup user2's PAT
result2, err := ts.GetUserByPATHash(ctx, pat2Hash)
require.NoError(t, err)
require.Equal(t, user2.ID, result2.UserID)
require.Equal(t, user2.Username, result2.User.Username)
ts.Close()
}
func TestUserSettingGetUserByPATHashMultiplePATsSameUser(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create multiple PATs for the same user
pat1Hash := "first-pat-hash"
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-1",
TokenHash: pat1Hash,
Description: "First PAT",
})
require.NoError(t, err)
pat2Hash := "second-pat-hash"
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-2",
TokenHash: pat2Hash,
Description: "Second PAT",
})
require.NoError(t, err)
// Both PATs should resolve to the same user
result1, err := ts.GetUserByPATHash(ctx, pat1Hash)
require.NoError(t, err)
require.Equal(t, user.ID, result1.UserID)
require.Equal(t, "pat-1", result1.PAT.TokenId)
result2, err := ts.GetUserByPATHash(ctx, pat2Hash)
require.NoError(t, err)
require.Equal(t, user.ID, result2.UserID)
require.Equal(t, "pat-2", result2.PAT.TokenId)
ts.Close()
}
func TestUserSettingUpdatePATLastUsed(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create a PAT
patHash := "pat-hash-for-update"
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-update-test",
TokenHash: patHash,
Description: "PAT for update test",
})
require.NoError(t, err)
// Update last used timestamp
now := timestamppb.Now()
err = ts.UpdatePATLastUsed(ctx, user.ID, "pat-update-test", now)
require.NoError(t, err)
// Verify the update
pats, err := ts.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
require.Len(t, pats, 1)
require.NotNil(t, pats[0].LastUsedAt)
ts.Close()
}
func TestUserSettingGetUserByPATHashWithExpiredToken(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create a PAT with expiration info
patHash := "pat-hash-with-expiry"
expiresAt := timestamppb.Now()
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-expiry-test",
TokenHash: patHash,
Description: "PAT with expiry",
ExpiresAt: expiresAt,
}
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat)
require.NoError(t, err)
// Should still be able to look up by hash (expiry check is done at auth level)
result, err := ts.GetUserByPATHash(ctx, patHash)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, user.ID, result.UserID)
require.NotNil(t, result.PAT.ExpiresAt)
ts.Close()
}
func TestUserSettingGetUserByPATHashAfterRemoval(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create a PAT
patHash := "pat-hash-to-remove"
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-remove-test",
TokenHash: patHash,
Description: "PAT to be removed",
})
require.NoError(t, err)
// Verify it exists
result, err := ts.GetUserByPATHash(ctx, patHash)
require.NoError(t, err)
require.NotNil(t, result)
// Remove the PAT
err = ts.RemoveUserPersonalAccessToken(ctx, user.ID, "pat-remove-test")
require.NoError(t, err)
// Should no longer be found
result, err = ts.GetUserByPATHash(ctx, patHash)
require.Error(t, err)
require.Nil(t, result)
ts.Close()
}
func TestUserSettingGetUserByPATHashSpecialCharacters(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create PATs with special characters in hash (simulating real hash values)
testCases := []struct {
tokenID string
tokenHash string
}{
{"pat-special-1", "abc123+/=XYZ"},
{"pat-special-2", "sha256:abcdef1234567890"},
{"pat-special-3", "$2a$10$N9qo8uLOickgx2ZMRZoMy"},
}
for _, tc := range testCases {
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: tc.tokenID,
TokenHash: tc.tokenHash,
Description: "PAT with special chars",
})
require.NoError(t, err)
// Verify lookup works with special characters
result, err := ts.GetUserByPATHash(ctx, tc.tokenHash)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, tc.tokenID, result.PAT.TokenId)
}
ts.Close()
}
func TestUserSettingGetUserByPATHashLargeTokenCount(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create many PATs for the same user
tokenCount := 10
hashes := make([]string, tokenCount)
for i := 0; i < tokenCount; i++ {
hashes[i] = "pat-hash-" + string(rune('A'+i)) + "-large-test"
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-large-" + string(rune('A'+i)),
TokenHash: hashes[i],
Description: "PAT for large count test",
})
require.NoError(t, err)
}
// Verify each hash can be looked up
for i, hash := range hashes {
result, err := ts.GetUserByPATHash(ctx, hash)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, user.ID, result.UserID)
require.Equal(t, "pat-large-"+string(rune('A'+i)), result.PAT.TokenId)
}
ts.Close()
}
func TestUserSettingMultipleSettingTypes(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Create GENERAL setting
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_GENERAL,
Value: &storepb.UserSetting_General{General: &storepb.GeneralUserSetting{Locale: "ja"}},
})
require.NoError(t, err)
// Create SHORTCUTS setting
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
{Id: "s1", Title: "Shortcut 1"},
},
}},
})
require.NoError(t, err)
// Add a PAT
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-multi",
TokenHash: "hash-multi",
})
require.NoError(t, err)
// List all settings for user
settings, err := ts.ListUserSettings(ctx, &store.FindUserSetting{UserID: &user.ID})
require.NoError(t, err)
require.Len(t, settings, 3)
// Verify each setting type
generalSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_GENERAL})
require.NoError(t, err)
require.Equal(t, "ja", generalSetting.GetGeneral().Locale)
shortcutsSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_SHORTCUTS})
require.NoError(t, err)
require.Len(t, shortcutsSetting.GetShortcuts().Shortcuts, 1)
patsSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_PERSONAL_ACCESS_TOKENS})
require.NoError(t, err)
require.Len(t, patsSetting.GetPersonalAccessTokens().Tokens, 1)
ts.Close()
}
func TestUserSettingShortcutsEdgeCases(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Case 1: Special characters in Filter and Title
// Includes quotes, backslashes, newlines, and other JSON-sensitive characters
specialCharsFilter := `tag in ["work", "project"] && content.contains("urgent")`
specialCharsTitle := `Work "Urgent" \ Notes`
shortcuts := &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
{Id: "s1", Title: specialCharsTitle, Filter: specialCharsFilter},
},
}
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
})
require.NoError(t, err)
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_SHORTCUTS,
})
require.NoError(t, err)
require.NotNil(t, setting)
require.Len(t, setting.GetShortcuts().Shortcuts, 1)
require.Equal(t, specialCharsTitle, setting.GetShortcuts().Shortcuts[0].Title)
require.Equal(t, specialCharsFilter, setting.GetShortcuts().Shortcuts[0].Filter)
// Case 2: Unicode characters
unicodeFilter := `tag in ["你好", "世界"]`
unicodeTitle := `My 🚀 Shortcuts`
shortcuts = &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
{Id: "s2", Title: unicodeTitle, Filter: unicodeFilter},
},
}
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
})
require.NoError(t, err)
setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_SHORTCUTS,
})
require.NoError(t, err)
require.NotNil(t, setting)
require.Len(t, setting.GetShortcuts().Shortcuts, 1)
require.Equal(t, unicodeTitle, setting.GetShortcuts().Shortcuts[0].Title)
require.Equal(t, unicodeFilter, setting.GetShortcuts().Shortcuts[0].Filter)
// Case 3: Empty shortcuts list
// Should allow saving an empty list (clearing shortcuts)
shortcuts = &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{},
}
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
})
require.NoError(t, err)
setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_SHORTCUTS,
})
require.NoError(t, err)
require.NotNil(t, setting)
require.NotNil(t, setting.GetShortcuts())
require.Len(t, setting.GetShortcuts().Shortcuts, 0)
// Case 4: Large filter string
// Test reasonable large string handling (e.g. 4KB)
largeFilter := strings.Repeat("tag:long_tag_name ", 200)
shortcuts = &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
{Id: "s3", Title: "Large Filter", Filter: largeFilter},
},
}
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
})
require.NoError(t, err)
setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_SHORTCUTS,
})
require.NoError(t, err)
require.NotNil(t, setting)
require.Equal(t, largeFilter, setting.GetShortcuts().Shortcuts[0].Filter)
ts.Close()
}
func TestUserSettingShortcutsPartialUpdate(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Initial set
shortcuts := &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
{Id: "s1", Title: "Note 1", Filter: "tag:1"},
{Id: "s2", Title: "Note 2", Filter: "tag:2"},
},
}
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
})
require.NoError(t, err)
// Update by replacing the whole list (Store Upsert replaces the value for the key)
// We want to verify that we can "update" a single item by sending the modified list
updatedShortcuts := &storepb.ShortcutsUserSetting{
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
{Id: "s1", Title: "Note 1 Updated", Filter: "tag:1_updated"},
{Id: "s2", Title: "Note 2", Filter: "tag:2"},
{Id: "s3", Title: "Note 3", Filter: "tag:3"}, // Add new one
},
}
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSetting_SHORTCUTS,
Value: &storepb.UserSetting_Shortcuts{Shortcuts: updatedShortcuts},
})
require.NoError(t, err)
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_SHORTCUTS,
})
require.NoError(t, err)
require.NotNil(t, setting)
require.Len(t, setting.GetShortcuts().Shortcuts, 3)
// Verify updates
for _, s := range setting.GetShortcuts().Shortcuts {
if s.Id == "s1" {
require.Equal(t, "Note 1 Updated", s.Title)
require.Equal(t, "tag:1_updated", s.Filter)
} else if s.Id == "s2" {
require.Equal(t, "Note 2", s.Title)
} else if s.Id == "s3" {
require.Equal(t, "Note 3", s.Title)
}
}
ts.Close()
}
func TestUserSettingJSONFieldsEdgeCases(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
// Case 1: Webhook with special characters and Unicode in Title and URL
specialWebhook := &storepb.WebhooksUserSetting_Webhook{
Id: "wh-special",
Title: `My "Special" & <Webhook> 🚀`,
Url: "https://example.com/hook?query=你好&param=\"value\"",
}
err = ts.AddUserWebhook(ctx, user.ID, specialWebhook)
require.NoError(t, err)
webhooks, err := ts.GetUserWebhooks(ctx, user.ID)
require.NoError(t, err)
require.Len(t, webhooks, 1)
require.Equal(t, specialWebhook.Title, webhooks[0].Title)
require.Equal(t, specialWebhook.Url, webhooks[0].Url)
// Case 2: PAT with special description
specialPAT := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
TokenId: "pat-special",
TokenHash: "hash-special",
Description: "Token for 'CLI' \n & \"API\" \t with unicode 🔑",
}
err = ts.AddUserPersonalAccessToken(ctx, user.ID, specialPAT)
require.NoError(t, err)
pats, err := ts.GetUserPersonalAccessTokens(ctx, user.ID)
require.NoError(t, err)
require.Len(t, pats, 1)
require.Equal(t, specialPAT.Description, pats[0].Description)
// Case 3: Refresh Token with special description
specialRefreshToken := &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: "rt-special",
Description: "Browser: Firefox (Nightly) / OS: Linux 🐧",
}
err = ts.AddUserRefreshToken(ctx, user.ID, specialRefreshToken)
require.NoError(t, err)
tokens, err := ts.GetUserRefreshTokens(ctx, user.ID)
require.NoError(t, err)
require.Len(t, tokens, 1)
require.Equal(t, specialRefreshToken.Description, tokens[0].Description)
ts.Close()
}

Some files were not shown because too many files have changed in this diff Show More