mirror of https://github.com/usememos/memos.git
Merge branch 'usememos:main' into feature/button-icon-divider
This commit is contained in:
commit
183fdb3c57
|
|
@ -1 +1,13 @@
|
|||
web/node_modules
|
||||
web/dist
|
||||
.git
|
||||
.github
|
||||
build/
|
||||
tmp/
|
||||
memos
|
||||
*.md
|
||||
.gitignore
|
||||
.golangci.yaml
|
||||
.dockerignore
|
||||
docs/
|
||||
.DS_Store
|
||||
|
|
@ -11,37 +11,79 @@ on:
|
|||
- "go.sum"
|
||||
- "**.go"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
go-static-checks:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: 1.25
|
||||
check-latest: true
|
||||
cache: true
|
||||
cache-dependency-path: go.sum
|
||||
|
||||
- name: Verify go.mod is tidy
|
||||
run: |
|
||||
go mod tidy -go=1.25
|
||||
git diff --exit-code
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.4.0
|
||||
args: --verbose --timeout=3m
|
||||
skip-cache: true
|
||||
args: --timeout=3m
|
||||
|
||||
go-tests:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-group:
|
||||
- store
|
||||
- server
|
||||
- plugin
|
||||
- other
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: 1.25
|
||||
check-latest: true
|
||||
cache: true
|
||||
- name: Run all tests
|
||||
run: go test -v ./... | tee test.log; exit ${PIPESTATUS[0]}
|
||||
- name: Pretty print tests running time
|
||||
run: grep --color=never -e '--- PASS:' -e '--- FAIL:' test.log | sed 's/[:()]//g' | awk '{print $2,$3,$4}' | sort -t' ' -nk3 -r | awk '{sum += $3; print $1,$2,$3,sum"s"}'
|
||||
cache-dependency-path: go.sum
|
||||
|
||||
- name: Run tests - ${{ matrix.test-group }}
|
||||
run: |
|
||||
case "${{ matrix.test-group }}" in
|
||||
store)
|
||||
# Run store tests for all drivers (sqlite, mysql, postgres)
|
||||
# The TestMain in store/test runs all drivers when DRIVER is not set
|
||||
# Note: We run without -race for container tests due to testcontainers race issues
|
||||
go test -v -coverprofile=coverage.out -covermode=atomic ./store/...
|
||||
;;
|
||||
server)
|
||||
go test -v -race -coverprofile=coverage.out -covermode=atomic ./server/...
|
||||
;;
|
||||
plugin)
|
||||
go test -v -race -coverprofile=coverage.out -covermode=atomic ./plugin/...
|
||||
;;
|
||||
other)
|
||||
go test -v -race -coverprofile=coverage.out -covermode=atomic \
|
||||
./cmd/... ./internal/... ./proto/...
|
||||
;;
|
||||
esac
|
||||
env:
|
||||
DRIVER: ${{ matrix.test-group == 'store' && '' || 'sqlite' }}
|
||||
|
||||
- name: Upload coverage
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
files: ./coverage.out
|
||||
flags: ${{ matrix.test-group }}
|
||||
fail_ci_if_error: false
|
||||
|
|
|
|||
|
|
@ -4,36 +4,128 @@ on:
|
|||
push:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
DOCKER_PLATFORMS: |
|
||||
linux/amd64
|
||||
linux/arm64
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.repository }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-and-push-canary-image:
|
||||
build-frontend:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: pnpm/action-setup@v4.2.0
|
||||
with:
|
||||
version: 10
|
||||
- uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: pnpm
|
||||
cache-dependency-path: "web/pnpm-lock.yaml"
|
||||
- name: Get pnpm store directory
|
||||
id: pnpm-cache
|
||||
shell: bash
|
||||
run: echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT
|
||||
- name: Setup pnpm cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ${{ steps.pnpm-cache.outputs.STORE_PATH }}
|
||||
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('web/pnpm-lock.yaml') }}
|
||||
restore-keys: ${{ runner.os }}-pnpm-store-
|
||||
- run: pnpm install --frozen-lockfile
|
||||
working-directory: web
|
||||
- name: Run frontend build
|
||||
run: pnpm release
|
||||
working-directory: web
|
||||
|
||||
- name: Upload frontend artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: frontend-dist
|
||||
path: server/router/frontend/dist
|
||||
retention-days: 1
|
||||
|
||||
build-push:
|
||||
needs: build-frontend
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Download frontend artifacts
|
||||
uses: actions/download-artifact@v7
|
||||
with:
|
||||
name: frontend-dist
|
||||
path: server/router/frontend/dist
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_TOKEN }}
|
||||
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ github.token }}
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./scripts/Dockerfile
|
||||
platforms: ${{ matrix.platform }}
|
||||
cache-from: type=gha,scope=build-${{ matrix.platform }}
|
||||
cache-to: type=gha,mode=max,scope=build-${{ matrix.platform }}
|
||||
outputs: type=image,name=neosmemo/memos,push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: digests-${{ strategy.job-index }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
merge:
|
||||
needs: build-push
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v7
|
||||
with:
|
||||
platforms: ${{ env.DOCKER_PLATFORMS }}
|
||||
pattern: digests-*
|
||||
merge-multiple: true
|
||||
path: /tmp/digests
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
id: buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
version: latest
|
||||
install: true
|
||||
platforms: ${{ env.DOCKER_PLATFORMS }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
|
|
@ -60,32 +152,15 @@ jobs:
|
|||
username: ${{ github.actor }}
|
||||
password: ${{ github.token }}
|
||||
|
||||
# Frontend build.
|
||||
- uses: pnpm/action-setup@v4.1.0
|
||||
with:
|
||||
version: 10
|
||||
- uses: actions/setup-node@v5
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: pnpm
|
||||
cache-dependency-path: "web/pnpm-lock.yaml"
|
||||
- run: pnpm install
|
||||
working-directory: web
|
||||
- name: Run frontend build
|
||||
run: pnpm release
|
||||
working-directory: web
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf 'neosmemo/memos@sha256:%s ' *)
|
||||
env:
|
||||
DOCKER_METADATA_OUTPUT_JSON: ${{ steps.meta.outputs.json }}
|
||||
|
||||
- name: Build and Push
|
||||
id: docker_build
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./scripts/Dockerfile
|
||||
platforms: ${{ env.DOCKER_PLATFORMS }}
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
build-args: |
|
||||
BUILDKIT_INLINE_CACHE=1
|
||||
- name: Inspect images
|
||||
run: |
|
||||
docker buildx imagetools inspect neosmemo/memos:canary
|
||||
docker buildx imagetools inspect ghcr.io/usememos/memos:canary
|
||||
|
|
|
|||
|
|
@ -7,41 +7,84 @@ on:
|
|||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
env:
|
||||
DOCKER_PLATFORMS: |
|
||||
linux/amd64
|
||||
linux/arm/v7
|
||||
linux/arm64
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
prepare:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
steps:
|
||||
- name: Extract version
|
||||
id: version
|
||||
run: |
|
||||
if [[ "$GITHUB_REF_TYPE" == "tag" ]]; then
|
||||
echo "version=${GITHUB_REF_NAME#v}" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "version=${GITHUB_REF_NAME#release/}" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
build-frontend:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: pnpm/action-setup@v4.2.0
|
||||
with:
|
||||
version: 10
|
||||
- uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: pnpm
|
||||
cache-dependency-path: "web/pnpm-lock.yaml"
|
||||
- name: Get pnpm store directory
|
||||
id: pnpm-cache
|
||||
shell: bash
|
||||
run: echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT
|
||||
- name: Setup pnpm cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ${{ steps.pnpm-cache.outputs.STORE_PATH }}
|
||||
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('web/pnpm-lock.yaml') }}
|
||||
restore-keys: ${{ runner.os }}-pnpm-store-
|
||||
- run: pnpm install --frozen-lockfile
|
||||
working-directory: web
|
||||
- name: Run frontend build
|
||||
run: pnpm release
|
||||
working-directory: web
|
||||
|
||||
- name: Upload frontend artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: frontend-dist
|
||||
path: server/router/frontend/dist
|
||||
retention-days: 1
|
||||
|
||||
build-push:
|
||||
needs: [prepare, build-frontend]
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform:
|
||||
- linux/amd64
|
||||
- linux/arm/v7
|
||||
- linux/arm64
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Download frontend artifacts
|
||||
uses: actions/download-artifact@v7
|
||||
with:
|
||||
name: frontend-dist
|
||||
path: server/router/frontend/dist
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
with:
|
||||
platforms: ${{ env.DOCKER_PLATFORMS }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
id: buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
version: latest
|
||||
install: true
|
||||
platforms: ${{ env.DOCKER_PLATFORMS }}
|
||||
|
||||
- name: Extract version
|
||||
run: |
|
||||
if [[ "$GITHUB_REF_TYPE" == "tag" ]]; then
|
||||
echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
|
||||
else
|
||||
echo "VERSION=${GITHUB_REF_NAME#release/}" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
|
|
@ -56,6 +99,48 @@ jobs:
|
|||
username: ${{ github.actor }}
|
||||
password: ${{ github.token }}
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./scripts/Dockerfile
|
||||
platforms: ${{ matrix.platform }}
|
||||
cache-from: type=gha,scope=build-${{ matrix.platform }}
|
||||
cache-to: type=gha,mode=max,scope=build-${{ matrix.platform }}
|
||||
outputs: type=image,name=neosmemo/memos,push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: digests-${{ strategy.job-index }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
merge:
|
||||
needs: [prepare, build-push]
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v7
|
||||
with:
|
||||
pattern: digests-*
|
||||
merge-multiple: true
|
||||
path: /tmp/digests
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
|
|
@ -64,40 +149,36 @@ jobs:
|
|||
neosmemo/memos
|
||||
ghcr.io/usememos/memos
|
||||
tags: |
|
||||
type=semver,pattern={{version}},value=${{ env.VERSION }}
|
||||
type=semver,pattern={{major}}.{{minor}},value=${{ env.VERSION }}
|
||||
type=semver,pattern={{version}},value=${{ needs.prepare.outputs.version }}
|
||||
type=semver,pattern={{major}}.{{minor}},value=${{ needs.prepare.outputs.version }}
|
||||
type=raw,value=stable
|
||||
flavor: |
|
||||
latest=false
|
||||
labels: |
|
||||
org.opencontainers.image.version=${{ env.VERSION }}
|
||||
org.opencontainers.image.version=${{ needs.prepare.outputs.version }}
|
||||
|
||||
# Frontend build.
|
||||
- uses: pnpm/action-setup@v4.1.0
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
version: 10
|
||||
- uses: actions/setup-node@v5
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: pnpm
|
||||
cache-dependency-path: "web/pnpm-lock.yaml"
|
||||
- run: pnpm install
|
||||
working-directory: web
|
||||
- name: Run frontend build
|
||||
run: pnpm release
|
||||
working-directory: web
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_TOKEN }}
|
||||
|
||||
- name: Build and Push
|
||||
id: docker_build
|
||||
uses: docker/build-push-action@v6
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
context: .
|
||||
file: ./scripts/Dockerfile
|
||||
platforms: ${{ env.DOCKER_PLATFORMS }}
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
build-args: |
|
||||
BUILDKIT_INLINE_CACHE=1
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ github.token }}
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf 'neosmemo/memos@sha256:%s ' *)
|
||||
env:
|
||||
DOCKER_METADATA_OUTPUT_JSON: ${{ steps.meta.outputs.json }}
|
||||
|
||||
- name: Inspect images
|
||||
run: |
|
||||
docker buildx imagetools inspect neosmemo/memos:stable
|
||||
docker buildx imagetools inspect ghcr.io/usememos/memos:stable
|
||||
|
|
|
|||
|
|
@ -13,11 +13,11 @@ jobs:
|
|||
static-checks:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: pnpm/action-setup@v4.1.0
|
||||
- uses: actions/checkout@v6
|
||||
- uses: pnpm/action-setup@v4.2.0
|
||||
with:
|
||||
version: 9
|
||||
- uses: actions/setup-node@v5
|
||||
- uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "20"
|
||||
cache: pnpm
|
||||
|
|
@ -31,11 +31,11 @@ jobs:
|
|||
frontend-build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: pnpm/action-setup@v4.1.0
|
||||
- uses: actions/checkout@v6
|
||||
- uses: pnpm/action-setup@v4.2.0
|
||||
with:
|
||||
version: 9
|
||||
- uses: actions/setup-node@v5
|
||||
- uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "20"
|
||||
cache: pnpm
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup buf
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ jobs:
|
|||
issues: write
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v10.0.0
|
||||
- uses: actions/stale@v10.1.1
|
||||
with:
|
||||
days-before-issue-stale: 14
|
||||
days-before-issue-close: 7
|
||||
|
|
|
|||
15
AGENTS.md
15
AGENTS.md
|
|
@ -165,7 +165,7 @@ type Driver interface {
|
|||
4. Demo mode: Seed with demo data
|
||||
|
||||
**Schema Versioning:**
|
||||
- Stored in `instance_setting` table (key: `bb.general.version`)
|
||||
- Stored in `system_setting` table
|
||||
- Format: `major.minor.patch`
|
||||
- Migration files: `store/migration/{driver}/{version}/NN__description.sql`
|
||||
- See: `store/migrator.go:21-414`
|
||||
|
|
@ -191,7 +191,7 @@ cd proto && buf generate
|
|||
|
||||
```bash
|
||||
# Start dev server
|
||||
go run ./cmd/memos --mode dev --port 8081
|
||||
go run ./cmd/memos --port 8081
|
||||
|
||||
# Run all tests
|
||||
go test ./...
|
||||
|
|
@ -458,7 +458,7 @@ cd web && pnpm lint
|
|||
|
||||
| Variable | Default | Description |
|
||||
|----------|----------|-------------|
|
||||
| `MEMOS_MODE` | `dev` | Mode: `dev`, `prod`, `demo` |
|
||||
| `MEMOS_DEMO` | `false` | Enable demo mode |
|
||||
| `MEMOS_PORT` | `8081` | HTTP port |
|
||||
| `MEMOS_ADDR` | `` | Bind address (empty = all) |
|
||||
| `MEMOS_DATA` | `~/.memos` | Data directory |
|
||||
|
|
@ -503,13 +503,6 @@ cd web && pnpm lint
|
|||
|
||||
## Common Tasks
|
||||
|
||||
### Debugging Database Issues
|
||||
|
||||
1. Check connection string in logs
|
||||
2. Verify `store/db/{driver}/migration/` files exist
|
||||
3. Check schema version: `SELECT * FROM instance_setting WHERE key = 'bb.general.version'`
|
||||
4. Test migration: `go test ./store/test/... -v`
|
||||
|
||||
### Debugging API Issues
|
||||
|
||||
1. Check Connect interceptor logs: `server/router/api/v1/connect_interceptors.go:79-105`
|
||||
|
|
@ -571,7 +564,7 @@ Each plugin has its own README with usage examples.
|
|||
|
||||
## Security Notes
|
||||
|
||||
- JWT secrets must be kept secret (`MEMOS_MODE=prod` generates random secret)
|
||||
- JWT secrets must be kept secret (generated on first run in production mode)
|
||||
- Personal Access Tokens stored as SHA-256 hashes in database
|
||||
- CSRF protection via SameSite cookies
|
||||
- CORS enabled for all origins (configure for production)
|
||||
|
|
|
|||
|
|
@ -22,10 +22,10 @@ An open-source, self-hosted note-taking service. Your thoughts, your data, your
|
|||
|
||||
---
|
||||
|
||||
[**LambdaTest** - Cross-browser testing cloud](https://www.lambdatest.com/?utm_source=memos&utm_medium=sponsor)
|
||||
[**TestMu AI** - The world’s first full-stack Agentic AI Quality Engineering platform](https://www.testmu.ai/?utm_source=memos&utm_medium=sponsor)
|
||||
|
||||
<a href="https://www.lambdatest.com/?utm_source=memos&utm_medium=sponsor" target="_blank" rel="noopener">
|
||||
<img src="https://www.lambdatest.com/blue-logo.png" alt="LambdaTest - Cross-browser testing cloud" height="50" />
|
||||
<a href="https://www.testmu.ai/?utm_source=memos&utm_medium=sponsor" target="_blank" rel="noopener">
|
||||
<img src="https://usememos.com/sponsors/testmu.svg" alt="TestMu AI" height="36" />
|
||||
</a>
|
||||
|
||||
## Overview
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ var (
|
|||
Short: `An open source, lightweight note-taking service. Easily capture and share your great thoughts.`,
|
||||
Run: func(_ *cobra.Command, _ []string) {
|
||||
instanceProfile := &profile.Profile{
|
||||
Mode: viper.GetString("mode"),
|
||||
Demo: viper.GetBool("demo"),
|
||||
Addr: viper.GetString("addr"),
|
||||
Port: viper.GetInt("port"),
|
||||
UNIXSock: viper.GetString("unix-sock"),
|
||||
|
|
@ -33,10 +33,12 @@ var (
|
|||
Driver: viper.GetString("driver"),
|
||||
DSN: viper.GetString("dsn"),
|
||||
InstanceURL: viper.GetString("instance-url"),
|
||||
Version: version.GetCurrentVersion(viper.GetString("mode")),
|
||||
}
|
||||
instanceProfile.Version = version.GetCurrentVersion()
|
||||
|
||||
if err := instanceProfile.Validate(); err != nil {
|
||||
panic(err)
|
||||
slog.Error("failed to validate profile", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
|
@ -71,6 +73,7 @@ var (
|
|||
if err != http.ErrServerClosed {
|
||||
slog.Error("failed to start server", "error", err)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -89,11 +92,11 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
viper.SetDefault("mode", "dev")
|
||||
viper.SetDefault("demo", false)
|
||||
viper.SetDefault("driver", "sqlite")
|
||||
viper.SetDefault("port", 8081)
|
||||
|
||||
rootCmd.PersistentFlags().String("mode", "dev", `mode of server, can be "prod" or "dev" or "demo"`)
|
||||
rootCmd.PersistentFlags().Bool("demo", false, "enable demo mode")
|
||||
rootCmd.PersistentFlags().String("addr", "", "address of server")
|
||||
rootCmd.PersistentFlags().Int("port", 8081, "port of server")
|
||||
rootCmd.PersistentFlags().String("unix-sock", "", "path to the unix socket, overrides --addr and --port")
|
||||
|
|
@ -102,7 +105,7 @@ func init() {
|
|||
rootCmd.PersistentFlags().String("dsn", "", "database source name(aka. DSN)")
|
||||
rootCmd.PersistentFlags().String("instance-url", "", "the url of your memos instance")
|
||||
|
||||
if err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode")); err != nil {
|
||||
if err := viper.BindPFlag("demo", rootCmd.PersistentFlags().Lookup("demo")); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := viper.BindPFlag("addr", rootCmd.PersistentFlags().Lookup("addr")); err != nil {
|
||||
|
|
@ -129,15 +132,12 @@ func init() {
|
|||
|
||||
viper.SetEnvPrefix("memos")
|
||||
viper.AutomaticEnv()
|
||||
if err := viper.BindEnv("instance-url", "MEMOS_INSTANCE_URL"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func printGreetings(profile *profile.Profile) {
|
||||
fmt.Printf("Memos %s started successfully!\n", profile.Version)
|
||||
|
||||
if profile.IsDev() {
|
||||
if profile.Demo {
|
||||
fmt.Fprint(os.Stderr, "Development mode is enabled\n")
|
||||
if profile.DSN != "" {
|
||||
fmt.Fprintf(os.Stderr, "Database: %s\n", profile.DSN)
|
||||
|
|
@ -147,7 +147,6 @@ func printGreetings(profile *profile.Profile) {
|
|||
// Server information
|
||||
fmt.Printf("Data directory: %s\n", profile.Data)
|
||||
fmt.Printf("Database driver: %s\n", profile.Driver)
|
||||
fmt.Printf("Mode: %s\n", profile.Mode)
|
||||
|
||||
// Connection information
|
||||
if len(profile.UNIXSock) == 0 {
|
||||
|
|
@ -170,6 +169,6 @@ func printGreetings(profile *profile.Profile) {
|
|||
|
||||
func main() {
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
panic(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
2
go.mod
2
go.mod
|
|
@ -9,6 +9,7 @@ require (
|
|||
github.com/aws/aws-sdk-go-v2/credentials v1.18.16
|
||||
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.19.4
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
|
||||
github.com/docker/docker v28.5.1+incompatible
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/google/cel-go v0.26.1
|
||||
github.com/google/uuid v1.6.0
|
||||
|
|
@ -50,7 +51,6 @@ require (
|
|||
github.com/containerd/platforms v0.2.1 // indirect
|
||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@ import (
|
|||
|
||||
// Profile is the configuration to start main server.
|
||||
type Profile struct {
|
||||
// Mode can be "prod" or "dev" or "demo"
|
||||
Mode string
|
||||
// Demo indicates if the server is in demo mode
|
||||
Demo bool
|
||||
// Addr is the binding address for server
|
||||
Addr string
|
||||
// Port is the binding port for server
|
||||
|
|
@ -34,15 +34,12 @@ type Profile struct {
|
|||
InstanceURL string
|
||||
}
|
||||
|
||||
func (p *Profile) IsDev() bool {
|
||||
return p.Mode != "prod"
|
||||
}
|
||||
|
||||
func checkDataDir(dataDir string) (string, error) {
|
||||
// Convert to absolute path if relative path is supplied.
|
||||
if !filepath.IsAbs(dataDir) {
|
||||
relativeDir := filepath.Join(filepath.Dir(os.Args[0]), dataDir)
|
||||
absDir, err := filepath.Abs(relativeDir)
|
||||
// Use current working directory, not the binary's directory
|
||||
// This ensures we use the actual working directory where the process runs
|
||||
absDir, err := filepath.Abs(dataDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
@ -58,21 +55,35 @@ func checkDataDir(dataDir string) (string, error) {
|
|||
}
|
||||
|
||||
func (p *Profile) Validate() error {
|
||||
if p.Mode != "demo" && p.Mode != "dev" && p.Mode != "prod" {
|
||||
p.Mode = "demo"
|
||||
}
|
||||
|
||||
if p.Mode == "prod" && p.Data == "" {
|
||||
// Set default data directory if not specified
|
||||
if p.Data == "" {
|
||||
if runtime.GOOS == "windows" {
|
||||
p.Data = filepath.Join(os.Getenv("ProgramData"), "memos")
|
||||
if _, err := os.Stat(p.Data); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(p.Data, 0770); err != nil {
|
||||
slog.Error("failed to create data directory", slog.String("data", p.Data), slog.String("error", err.Error()))
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
p.Data = "/var/opt/memos"
|
||||
// On Linux/macOS, check if /var/opt/memos exists and is writable (Docker scenario)
|
||||
if info, err := os.Stat("/var/opt/memos"); err == nil && info.IsDir() {
|
||||
// Check if we can write to this directory
|
||||
testFile := filepath.Join("/var/opt/memos", ".write-test")
|
||||
if err := os.WriteFile(testFile, []byte("test"), 0600); err == nil {
|
||||
os.Remove(testFile)
|
||||
p.Data = "/var/opt/memos"
|
||||
} else {
|
||||
// /var/opt/memos exists but is not writable, use current directory
|
||||
slog.Warn("/var/opt/memos is not writable, using current directory")
|
||||
p.Data = "."
|
||||
}
|
||||
} else {
|
||||
// /var/opt/memos doesn't exist, use current directory (local development)
|
||||
p.Data = "."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create data directory if it doesn't exist
|
||||
if _, err := os.Stat(p.Data); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(p.Data, 0770); err != nil {
|
||||
slog.Error("failed to create data directory", slog.String("data", p.Data), slog.String("error", err.Error()))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -84,7 +95,11 @@ func (p *Profile) Validate() error {
|
|||
|
||||
p.Data = dataDir
|
||||
if p.Driver == "sqlite" && p.DSN == "" {
|
||||
dbFile := fmt.Sprintf("memos_%s.db", p.Mode)
|
||||
mode := "prod"
|
||||
if p.Demo {
|
||||
mode = "demo"
|
||||
}
|
||||
dbFile := fmt.Sprintf("memos_%s.db", mode)
|
||||
p.DSN = filepath.Join(dataDir, dbFile)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,15 +9,9 @@ import (
|
|||
|
||||
// Version is the service current released version.
|
||||
// Semantic versioning: https://semver.org/
|
||||
var Version = "0.25.3"
|
||||
var Version = "0.26.0"
|
||||
|
||||
// DevVersion is the service current development version.
|
||||
var DevVersion = "0.25.3"
|
||||
|
||||
func GetCurrentVersion(mode string) string {
|
||||
if mode == "dev" || mode == "demo" {
|
||||
return DevVersion
|
||||
}
|
||||
func GetCurrentVersion() string {
|
||||
return Version
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
package email
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func TestSend(t *testing.T) {
|
||||
|
|
@ -106,34 +106,22 @@ func TestSendAsyncConcurrent(t *testing.T) {
|
|||
FromEmail: "test@example.com",
|
||||
}
|
||||
|
||||
// Send multiple emails concurrently
|
||||
var wg sync.WaitGroup
|
||||
g := errgroup.Group{}
|
||||
count := 5
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
g.Go(func() error {
|
||||
message := &Message{
|
||||
To: []string{"recipient@example.com"},
|
||||
Subject: "Concurrent Test",
|
||||
Body: "Test body",
|
||||
}
|
||||
SendAsync(config, message)
|
||||
}()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Should complete without deadlock
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
done <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("SendAsync calls did not complete in time")
|
||||
if err := g.Wait(); err != nil {
|
||||
t.Fatalf("SendAsync calls failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -114,3 +114,46 @@ type FunctionValue struct {
|
|||
}
|
||||
|
||||
func (*FunctionValue) isValueExpr() {}
|
||||
|
||||
// ListComprehensionCondition represents CEL macros like exists(), all(), filter().
|
||||
type ListComprehensionCondition struct {
|
||||
Kind ComprehensionKind
|
||||
Field string // The list field to iterate over (e.g., "tags")
|
||||
IterVar string // The iteration variable name (e.g., "t")
|
||||
Predicate PredicateExpr // The predicate to evaluate for each element
|
||||
}
|
||||
|
||||
func (*ListComprehensionCondition) isCondition() {}
|
||||
|
||||
// ComprehensionKind enumerates the types of list comprehensions.
|
||||
type ComprehensionKind string
|
||||
|
||||
const (
|
||||
ComprehensionExists ComprehensionKind = "exists"
|
||||
)
|
||||
|
||||
// PredicateExpr represents predicates used in comprehensions.
|
||||
type PredicateExpr interface {
|
||||
isPredicateExpr()
|
||||
}
|
||||
|
||||
// StartsWithPredicate represents t.startsWith("prefix").
|
||||
type StartsWithPredicate struct {
|
||||
Prefix string
|
||||
}
|
||||
|
||||
func (*StartsWithPredicate) isPredicateExpr() {}
|
||||
|
||||
// EndsWithPredicate represents t.endsWith("suffix").
|
||||
type EndsWithPredicate struct {
|
||||
Suffix string
|
||||
}
|
||||
|
||||
func (*EndsWithPredicate) isPredicateExpr() {}
|
||||
|
||||
// ContainsPredicate represents t.contains("substring").
|
||||
type ContainsPredicate struct {
|
||||
Substring string
|
||||
}
|
||||
|
||||
func (*ContainsPredicate) isPredicateExpr() {}
|
||||
|
|
|
|||
|
|
@ -36,6 +36,8 @@ func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
|
|||
return nil, errors.Errorf("identifier %q is not boolean", name)
|
||||
}
|
||||
return &FieldPredicateCondition{Field: name}, nil
|
||||
case *exprv1.Expr_ComprehensionExpr:
|
||||
return buildComprehensionCondition(v.ComprehensionExpr, schema)
|
||||
default:
|
||||
return nil, errors.New("unsupported top-level expression")
|
||||
}
|
||||
|
|
@ -415,3 +417,170 @@ func evaluateNumeric(expr *exprv1.Expr) (int64, bool, error) {
|
|||
func timeNowUnix() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
// buildComprehensionCondition handles CEL comprehension expressions (exists, all, etc.).
|
||||
func buildComprehensionCondition(comp *exprv1.Expr_Comprehension, schema Schema) (Condition, error) {
|
||||
// Determine the comprehension kind by examining the loop initialization and step
|
||||
kind, err := detectComprehensionKind(comp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the field being iterated over
|
||||
iterRangeIdent := comp.IterRange.GetIdentExpr()
|
||||
if iterRangeIdent == nil {
|
||||
return nil, errors.New("comprehension range must be a field identifier")
|
||||
}
|
||||
fieldName := iterRangeIdent.GetName()
|
||||
|
||||
// Validate the field
|
||||
field, ok := schema.Field(fieldName)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unknown field %q in comprehension", fieldName)
|
||||
}
|
||||
if field.Kind != FieldKindJSONList {
|
||||
return nil, errors.Errorf("field %q does not support comprehension (must be a list)", fieldName)
|
||||
}
|
||||
|
||||
// Extract the predicate from the loop step
|
||||
predicate, err := extractPredicate(comp, schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ListComprehensionCondition{
|
||||
Kind: kind,
|
||||
Field: fieldName,
|
||||
IterVar: comp.IterVar,
|
||||
Predicate: predicate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// detectComprehensionKind determines if this is an exists() macro.
|
||||
// Only exists() is currently supported.
|
||||
func detectComprehensionKind(comp *exprv1.Expr_Comprehension) (ComprehensionKind, error) {
|
||||
// Check the accumulator initialization
|
||||
accuInit := comp.AccuInit.GetConstExpr()
|
||||
if accuInit == nil {
|
||||
return "", errors.New("comprehension accumulator must be initialized with a constant")
|
||||
}
|
||||
|
||||
// exists() starts with false and uses OR (||) in loop step
|
||||
if !accuInit.GetBoolValue() {
|
||||
if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_||_" {
|
||||
return ComprehensionExists, nil
|
||||
}
|
||||
}
|
||||
|
||||
// all() starts with true and uses AND (&&) - not supported
|
||||
if accuInit.GetBoolValue() {
|
||||
if step := comp.LoopStep.GetCallExpr(); step != nil && step.Function == "_&&_" {
|
||||
return "", errors.New("all() comprehension is not supported; use exists() instead")
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("unsupported comprehension type; only exists() is supported")
|
||||
}
|
||||
|
||||
// extractPredicate extracts the predicate expression from the comprehension loop step.
|
||||
func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr, error) {
|
||||
// The loop step is: @result || predicate(t) for exists
|
||||
// or: @result && predicate(t) for all
|
||||
step := comp.LoopStep.GetCallExpr()
|
||||
if step == nil {
|
||||
return nil, errors.New("comprehension loop step must be a call expression")
|
||||
}
|
||||
|
||||
if len(step.Args) != 2 {
|
||||
return nil, errors.New("comprehension loop step must have two arguments")
|
||||
}
|
||||
|
||||
// The predicate is the second argument
|
||||
predicateExpr := step.Args[1]
|
||||
predicateCall := predicateExpr.GetCallExpr()
|
||||
if predicateCall == nil {
|
||||
return nil, errors.New("comprehension predicate must be a function call")
|
||||
}
|
||||
|
||||
// Handle different predicate functions
|
||||
switch predicateCall.Function {
|
||||
case "startsWith":
|
||||
return buildStartsWithPredicate(predicateCall, comp.IterVar)
|
||||
case "endsWith":
|
||||
return buildEndsWithPredicate(predicateCall, comp.IterVar)
|
||||
case "contains":
|
||||
return buildContainsPredicate(predicateCall, comp.IterVar)
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function)
|
||||
}
|
||||
}
|
||||
|
||||
// buildStartsWithPredicate extracts the pattern from t.startsWith("prefix").
|
||||
func buildStartsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
// Verify the target is the iteration variable
|
||||
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
|
||||
return nil, errors.Errorf("startsWith target must be the iteration variable %q", iterVar)
|
||||
}
|
||||
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("startsWith expects exactly one argument")
|
||||
}
|
||||
|
||||
prefix, err := getConstValue(call.Args[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "startsWith argument must be a constant string")
|
||||
}
|
||||
|
||||
prefixStr, ok := prefix.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("startsWith argument must be a string")
|
||||
}
|
||||
|
||||
return &StartsWithPredicate{Prefix: prefixStr}, nil
|
||||
}
|
||||
|
||||
// buildEndsWithPredicate extracts the pattern from t.endsWith("suffix").
|
||||
func buildEndsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
|
||||
return nil, errors.Errorf("endsWith target must be the iteration variable %q", iterVar)
|
||||
}
|
||||
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("endsWith expects exactly one argument")
|
||||
}
|
||||
|
||||
suffix, err := getConstValue(call.Args[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "endsWith argument must be a constant string")
|
||||
}
|
||||
|
||||
suffixStr, ok := suffix.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("endsWith argument must be a string")
|
||||
}
|
||||
|
||||
return &EndsWithPredicate{Suffix: suffixStr}, nil
|
||||
}
|
||||
|
||||
// buildContainsPredicate extracts the pattern from t.contains("substring").
|
||||
func buildContainsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
if target := call.Target.GetIdentExpr(); target == nil || target.GetName() != iterVar {
|
||||
return nil, errors.Errorf("contains target must be the iteration variable %q", iterVar)
|
||||
}
|
||||
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("contains expects exactly one argument")
|
||||
}
|
||||
|
||||
substring, err := getConstValue(call.Args[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "contains argument must be a constant string")
|
||||
}
|
||||
|
||||
substringStr, ok := substring.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("contains argument must be a string")
|
||||
}
|
||||
|
||||
return &ContainsPredicate{Substring: substringStr}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,6 +74,8 @@ func (r *renderer) renderCondition(cond Condition) (renderResult, error) {
|
|||
return r.renderElementInCondition(c)
|
||||
case *ContainsCondition:
|
||||
return r.renderContainsCondition(c)
|
||||
case *ListComprehensionCondition:
|
||||
return r.renderListComprehension(c)
|
||||
case *ConstantCondition:
|
||||
if c.Value {
|
||||
return renderResult{trivial: true}, nil
|
||||
|
|
@ -461,13 +463,108 @@ func (r *renderer) renderContainsCondition(cond *ContainsCondition) (renderResul
|
|||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (renderResult, error) {
|
||||
field, ok := r.schema.Field(cond.Field)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
||||
}
|
||||
|
||||
if field.Kind != FieldKindJSONList {
|
||||
return renderResult{}, errors.Errorf("field %q is not a JSON list", cond.Field)
|
||||
}
|
||||
|
||||
// Render based on predicate type
|
||||
switch pred := cond.Predicate.(type) {
|
||||
case *StartsWithPredicate:
|
||||
return r.renderTagStartsWith(field, pred.Prefix, cond.Kind)
|
||||
case *EndsWithPredicate:
|
||||
return r.renderTagEndsWith(field, pred.Suffix, cond.Kind)
|
||||
case *ContainsPredicate:
|
||||
return r.renderTagContains(field, pred.Substring, cond.Kind)
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported predicate type %T in comprehension", pred)
|
||||
}
|
||||
}
|
||||
|
||||
// renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")).
|
||||
func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
|
||||
switch r.dialect {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
// Match exact tag or tags with this prefix (hierarchical support)
|
||||
exactMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s"%%`, prefix))
|
||||
prefixMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s%%`, prefix))
|
||||
condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil
|
||||
|
||||
case DialectPostgres:
|
||||
// Use PostgreSQL's powerful JSON operators
|
||||
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(fmt.Sprintf(`"%s"`, prefix)))
|
||||
prefixMatch := fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(fmt.Sprintf(`%%"%s%%`, prefix)))
|
||||
condition := fmt.Sprintf("(%s OR %s)", exactMatch, prefixMatch)
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, condition)}, nil
|
||||
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
// renderTagEndsWith generates SQL for tags.exists(t, t.endsWith("suffix")).
|
||||
func (r *renderer) renderTagEndsWith(field Field, suffix string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
pattern := fmt.Sprintf(`%%%s"%%`, suffix)
|
||||
|
||||
likeExpr := r.buildJSONArrayLike(arrayExpr, pattern)
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil
|
||||
}
|
||||
|
||||
// renderTagContains generates SQL for tags.exists(t, t.contains("substring")).
|
||||
func (r *renderer) renderTagContains(field Field, substring string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
pattern := fmt.Sprintf(`%%%s%%`, substring)
|
||||
|
||||
likeExpr := r.buildJSONArrayLike(arrayExpr, pattern)
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, likeExpr)}, nil
|
||||
}
|
||||
|
||||
// buildJSONArrayLike builds a LIKE expression for matching within a JSON array.
|
||||
// Returns the LIKE clause without NULL/empty checks.
|
||||
func (r *renderer) buildJSONArrayLike(arrayExpr, pattern string) string {
|
||||
switch r.dialect {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
return fmt.Sprintf("%s LIKE %s", arrayExpr, r.addArg(pattern))
|
||||
case DialectPostgres:
|
||||
return fmt.Sprintf("(%s)::text LIKE %s", arrayExpr, r.addArg(pattern))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// wrapWithNullCheck wraps a condition with NULL and empty array checks.
|
||||
// This ensures we don't match against NULL or empty JSON arrays.
|
||||
func (r *renderer) wrapWithNullCheck(arrayExpr, condition string) string {
|
||||
var nullCheck string
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
nullCheck = fmt.Sprintf("%s IS NOT NULL AND %s != '[]'", arrayExpr, arrayExpr)
|
||||
case DialectMySQL:
|
||||
nullCheck = fmt.Sprintf("%s IS NOT NULL AND JSON_LENGTH(%s) > 0", arrayExpr, arrayExpr)
|
||||
case DialectPostgres:
|
||||
nullCheck = fmt.Sprintf("%s IS NOT NULL AND jsonb_array_length(%s) > 0", arrayExpr, arrayExpr)
|
||||
default:
|
||||
return condition
|
||||
}
|
||||
return fmt.Sprintf("(%s AND %s)", condition, nullCheck)
|
||||
}
|
||||
|
||||
func (r *renderer) jsonBoolPredicate(field Field) (string, error) {
|
||||
expr := jsonExtractExpr(r.dialect, field)
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
return fmt.Sprintf("%s IS TRUE", expr), nil
|
||||
case DialectMySQL:
|
||||
return fmt.Sprintf("%s = CAST('true' AS JSON)", expr), nil
|
||||
return fmt.Sprintf("COALESCE(%s, CAST('false' AS JSON)) = CAST('true' AS JSON)", expr), nil
|
||||
case DialectPostgres:
|
||||
return fmt.Sprintf("(%s)::boolean IS TRUE", expr), nil
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -256,7 +256,7 @@ func NewAttachmentSchema() Schema {
|
|||
Name: "filename",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeString,
|
||||
Column: Column{Table: "resource", Name: "filename"},
|
||||
Column: Column{Table: "attachment", Name: "filename"},
|
||||
SupportsContains: true,
|
||||
Expressions: map[DialectName]string{},
|
||||
},
|
||||
|
|
@ -264,14 +264,14 @@ func NewAttachmentSchema() Schema {
|
|||
Name: "mime_type",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeString,
|
||||
Column: Column{Table: "resource", Name: "type"},
|
||||
Column: Column{Table: "attachment", Name: "type"},
|
||||
Expressions: map[DialectName]string{},
|
||||
},
|
||||
"create_time": {
|
||||
Name: "create_time",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeTimestamp,
|
||||
Column: Column{Table: "resource", Name: "created_ts"},
|
||||
Column: Column{Table: "attachment", Name: "created_ts"},
|
||||
Expressions: map[DialectName]string{
|
||||
// MySQL stores created_ts as TIMESTAMP, needs conversion to epoch
|
||||
DialectMySQL: "UNIX_TIMESTAMP(%s)",
|
||||
|
|
@ -284,7 +284,7 @@ func NewAttachmentSchema() Schema {
|
|||
Name: "memo_id",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeInt,
|
||||
Column: Column{Table: "resource", Name: "memo_id"},
|
||||
Column: Column{Table: "attachment", Name: "memo_id"},
|
||||
Expressions: map[DialectName]string{},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ syntax = "proto3";
|
|||
|
||||
package memos.api.v1;
|
||||
|
||||
import "api/v1/user_service.proto";
|
||||
import "google/api/annotations.proto";
|
||||
import "google/api/client.proto";
|
||||
import "google/api/field_behavior.proto";
|
||||
|
|
@ -34,18 +35,18 @@ service InstanceService {
|
|||
|
||||
// Instance profile message containing basic instance information.
|
||||
message InstanceProfile {
|
||||
// The name of instance owner.
|
||||
// Format: users/{user}
|
||||
string owner = 1;
|
||||
|
||||
// Version is the current version of instance.
|
||||
string version = 2;
|
||||
|
||||
// Mode is the instance mode (e.g. "prod", "dev" or "demo").
|
||||
string mode = 3;
|
||||
// Demo indicates if the instance is in demo mode.
|
||||
bool demo = 3;
|
||||
|
||||
// Instance URL is the URL of the instance.
|
||||
string instance_url = 6;
|
||||
|
||||
// The first administrator who set up this instance.
|
||||
// When null, instance requires initial setup (creating the first admin account).
|
||||
User admin = 7;
|
||||
}
|
||||
|
||||
// Request for instance profile.
|
||||
|
|
|
|||
|
|
@ -173,11 +173,13 @@ message Memo {
|
|||
(google.api.resource_reference) = {type: "memos.api.v1/User"}
|
||||
];
|
||||
|
||||
// Output only. The creation timestamp.
|
||||
google.protobuf.Timestamp create_time = 4 [(google.api.field_behavior) = OUTPUT_ONLY];
|
||||
// The creation timestamp.
|
||||
// If not set on creation, the server will set it to the current time.
|
||||
google.protobuf.Timestamp create_time = 4 [(google.api.field_behavior) = OPTIONAL];
|
||||
|
||||
// Output only. The last update timestamp.
|
||||
google.protobuf.Timestamp update_time = 5 [(google.api.field_behavior) = OUTPUT_ONLY];
|
||||
// The last update timestamp.
|
||||
// If not set on creation, the server will set it to the current time.
|
||||
google.protobuf.Timestamp update_time = 5 [(google.api.field_behavior) = OPTIONAL];
|
||||
|
||||
// The display timestamp of the memo.
|
||||
google.protobuf.Timestamp display_time = 6 [(google.api.field_behavior) = OPTIONAL];
|
||||
|
|
|
|||
|
|
@ -203,13 +203,10 @@ message User {
|
|||
|
||||
// User role enumeration.
|
||||
enum Role {
|
||||
// Unspecified role.
|
||||
ROLE_UNSPECIFIED = 0;
|
||||
// Host role with full system access.
|
||||
HOST = 1;
|
||||
// Admin role with administrative privileges.
|
||||
// Admin role with system access.
|
||||
ADMIN = 2;
|
||||
// Regular user role.
|
||||
// User role with limited access.
|
||||
USER = 3;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -138,15 +138,15 @@ func (InstanceSetting_StorageSetting_StorageType) EnumDescriptor() ([]byte, []in
|
|||
// Instance profile message containing basic instance information.
|
||||
type InstanceProfile struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// The name of instance owner.
|
||||
// Format: users/{user}
|
||||
Owner string `protobuf:"bytes,1,opt,name=owner,proto3" json:"owner,omitempty"`
|
||||
// Version is the current version of instance.
|
||||
Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"`
|
||||
// Mode is the instance mode (e.g. "prod", "dev" or "demo").
|
||||
Mode string `protobuf:"bytes,3,opt,name=mode,proto3" json:"mode,omitempty"`
|
||||
// Demo indicates if the instance is in demo mode.
|
||||
Demo bool `protobuf:"varint,3,opt,name=demo,proto3" json:"demo,omitempty"`
|
||||
// Instance URL is the URL of the instance.
|
||||
InstanceUrl string `protobuf:"bytes,6,opt,name=instance_url,json=instanceUrl,proto3" json:"instance_url,omitempty"`
|
||||
InstanceUrl string `protobuf:"bytes,6,opt,name=instance_url,json=instanceUrl,proto3" json:"instance_url,omitempty"`
|
||||
// The first administrator who set up this instance.
|
||||
// When null, instance requires initial setup (creating the first admin account).
|
||||
Admin *User `protobuf:"bytes,7,opt,name=admin,proto3" json:"admin,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
|
@ -181,13 +181,6 @@ func (*InstanceProfile) Descriptor() ([]byte, []int) {
|
|||
return file_api_v1_instance_service_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *InstanceProfile) GetOwner() string {
|
||||
if x != nil {
|
||||
return x.Owner
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *InstanceProfile) GetVersion() string {
|
||||
if x != nil {
|
||||
return x.Version
|
||||
|
|
@ -195,11 +188,11 @@ func (x *InstanceProfile) GetVersion() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (x *InstanceProfile) GetMode() string {
|
||||
func (x *InstanceProfile) GetDemo() bool {
|
||||
if x != nil {
|
||||
return x.Mode
|
||||
return x.Demo
|
||||
}
|
||||
return ""
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *InstanceProfile) GetInstanceUrl() string {
|
||||
|
|
@ -209,6 +202,13 @@ func (x *InstanceProfile) GetInstanceUrl() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (x *InstanceProfile) GetAdmin() *User {
|
||||
if x != nil {
|
||||
return x.Admin
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Request for instance profile.
|
||||
type GetInstanceProfileRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
|
|
@ -875,12 +875,12 @@ var File_api_v1_instance_service_proto protoreflect.FileDescriptor
|
|||
|
||||
const file_api_v1_instance_service_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x1dapi/v1/instance_service.proto\x12\fmemos.api.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a google/protobuf/field_mask.proto\"x\n" +
|
||||
"\x0fInstanceProfile\x12\x14\n" +
|
||||
"\x05owner\x18\x01 \x01(\tR\x05owner\x12\x18\n" +
|
||||
"\x1dapi/v1/instance_service.proto\x12\fmemos.api.v1\x1a\x19api/v1/user_service.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a google/protobuf/field_mask.proto\"\x8c\x01\n" +
|
||||
"\x0fInstanceProfile\x12\x18\n" +
|
||||
"\aversion\x18\x02 \x01(\tR\aversion\x12\x12\n" +
|
||||
"\x04mode\x18\x03 \x01(\tR\x04mode\x12!\n" +
|
||||
"\finstance_url\x18\x06 \x01(\tR\vinstanceUrl\"\x1b\n" +
|
||||
"\x04demo\x18\x03 \x01(\bR\x04demo\x12!\n" +
|
||||
"\finstance_url\x18\x06 \x01(\tR\vinstanceUrl\x12(\n" +
|
||||
"\x05admin\x18\a \x01(\v2\x12.memos.api.v1.UserR\x05admin\"\x1b\n" +
|
||||
"\x19GetInstanceProfileRequest\"\x99\x0f\n" +
|
||||
"\x0fInstanceSetting\x12\x17\n" +
|
||||
"\x04name\x18\x01 \x01(\tB\x03\xe0A\bR\x04name\x12W\n" +
|
||||
|
|
@ -970,28 +970,30 @@ var file_api_v1_instance_service_proto_goTypes = []any{
|
|||
(*InstanceSetting_MemoRelatedSetting)(nil), // 9: memos.api.v1.InstanceSetting.MemoRelatedSetting
|
||||
(*InstanceSetting_GeneralSetting_CustomProfile)(nil), // 10: memos.api.v1.InstanceSetting.GeneralSetting.CustomProfile
|
||||
(*InstanceSetting_StorageSetting_S3Config)(nil), // 11: memos.api.v1.InstanceSetting.StorageSetting.S3Config
|
||||
(*fieldmaskpb.FieldMask)(nil), // 12: google.protobuf.FieldMask
|
||||
(*User)(nil), // 12: memos.api.v1.User
|
||||
(*fieldmaskpb.FieldMask)(nil), // 13: google.protobuf.FieldMask
|
||||
}
|
||||
var file_api_v1_instance_service_proto_depIdxs = []int32{
|
||||
7, // 0: memos.api.v1.InstanceSetting.general_setting:type_name -> memos.api.v1.InstanceSetting.GeneralSetting
|
||||
8, // 1: memos.api.v1.InstanceSetting.storage_setting:type_name -> memos.api.v1.InstanceSetting.StorageSetting
|
||||
9, // 2: memos.api.v1.InstanceSetting.memo_related_setting:type_name -> memos.api.v1.InstanceSetting.MemoRelatedSetting
|
||||
4, // 3: memos.api.v1.UpdateInstanceSettingRequest.setting:type_name -> memos.api.v1.InstanceSetting
|
||||
12, // 4: memos.api.v1.UpdateInstanceSettingRequest.update_mask:type_name -> google.protobuf.FieldMask
|
||||
10, // 5: memos.api.v1.InstanceSetting.GeneralSetting.custom_profile:type_name -> memos.api.v1.InstanceSetting.GeneralSetting.CustomProfile
|
||||
1, // 6: memos.api.v1.InstanceSetting.StorageSetting.storage_type:type_name -> memos.api.v1.InstanceSetting.StorageSetting.StorageType
|
||||
11, // 7: memos.api.v1.InstanceSetting.StorageSetting.s3_config:type_name -> memos.api.v1.InstanceSetting.StorageSetting.S3Config
|
||||
3, // 8: memos.api.v1.InstanceService.GetInstanceProfile:input_type -> memos.api.v1.GetInstanceProfileRequest
|
||||
5, // 9: memos.api.v1.InstanceService.GetInstanceSetting:input_type -> memos.api.v1.GetInstanceSettingRequest
|
||||
6, // 10: memos.api.v1.InstanceService.UpdateInstanceSetting:input_type -> memos.api.v1.UpdateInstanceSettingRequest
|
||||
2, // 11: memos.api.v1.InstanceService.GetInstanceProfile:output_type -> memos.api.v1.InstanceProfile
|
||||
4, // 12: memos.api.v1.InstanceService.GetInstanceSetting:output_type -> memos.api.v1.InstanceSetting
|
||||
4, // 13: memos.api.v1.InstanceService.UpdateInstanceSetting:output_type -> memos.api.v1.InstanceSetting
|
||||
11, // [11:14] is the sub-list for method output_type
|
||||
8, // [8:11] is the sub-list for method input_type
|
||||
8, // [8:8] is the sub-list for extension type_name
|
||||
8, // [8:8] is the sub-list for extension extendee
|
||||
0, // [0:8] is the sub-list for field type_name
|
||||
12, // 0: memos.api.v1.InstanceProfile.admin:type_name -> memos.api.v1.User
|
||||
7, // 1: memos.api.v1.InstanceSetting.general_setting:type_name -> memos.api.v1.InstanceSetting.GeneralSetting
|
||||
8, // 2: memos.api.v1.InstanceSetting.storage_setting:type_name -> memos.api.v1.InstanceSetting.StorageSetting
|
||||
9, // 3: memos.api.v1.InstanceSetting.memo_related_setting:type_name -> memos.api.v1.InstanceSetting.MemoRelatedSetting
|
||||
4, // 4: memos.api.v1.UpdateInstanceSettingRequest.setting:type_name -> memos.api.v1.InstanceSetting
|
||||
13, // 5: memos.api.v1.UpdateInstanceSettingRequest.update_mask:type_name -> google.protobuf.FieldMask
|
||||
10, // 6: memos.api.v1.InstanceSetting.GeneralSetting.custom_profile:type_name -> memos.api.v1.InstanceSetting.GeneralSetting.CustomProfile
|
||||
1, // 7: memos.api.v1.InstanceSetting.StorageSetting.storage_type:type_name -> memos.api.v1.InstanceSetting.StorageSetting.StorageType
|
||||
11, // 8: memos.api.v1.InstanceSetting.StorageSetting.s3_config:type_name -> memos.api.v1.InstanceSetting.StorageSetting.S3Config
|
||||
3, // 9: memos.api.v1.InstanceService.GetInstanceProfile:input_type -> memos.api.v1.GetInstanceProfileRequest
|
||||
5, // 10: memos.api.v1.InstanceService.GetInstanceSetting:input_type -> memos.api.v1.GetInstanceSettingRequest
|
||||
6, // 11: memos.api.v1.InstanceService.UpdateInstanceSetting:input_type -> memos.api.v1.UpdateInstanceSettingRequest
|
||||
2, // 12: memos.api.v1.InstanceService.GetInstanceProfile:output_type -> memos.api.v1.InstanceProfile
|
||||
4, // 13: memos.api.v1.InstanceService.GetInstanceSetting:output_type -> memos.api.v1.InstanceSetting
|
||||
4, // 14: memos.api.v1.InstanceService.UpdateInstanceSetting:output_type -> memos.api.v1.InstanceSetting
|
||||
12, // [12:15] is the sub-list for method output_type
|
||||
9, // [9:12] is the sub-list for method input_type
|
||||
9, // [9:9] is the sub-list for extension type_name
|
||||
9, // [9:9] is the sub-list for extension extendee
|
||||
0, // [0:9] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_api_v1_instance_service_proto_init() }
|
||||
|
|
@ -999,6 +1001,7 @@ func file_api_v1_instance_service_proto_init() {
|
|||
if File_api_v1_instance_service_proto != nil {
|
||||
return
|
||||
}
|
||||
file_api_v1_user_service_proto_init()
|
||||
file_api_v1_instance_service_proto_msgTypes[2].OneofWrappers = []any{
|
||||
(*InstanceSetting_GeneralSetting_)(nil),
|
||||
(*InstanceSetting_StorageSetting_)(nil),
|
||||
|
|
|
|||
|
|
@ -222,9 +222,11 @@ type Memo struct {
|
|||
// The name of the creator.
|
||||
// Format: users/{user}
|
||||
Creator string `protobuf:"bytes,3,opt,name=creator,proto3" json:"creator,omitempty"`
|
||||
// Output only. The creation timestamp.
|
||||
// The creation timestamp.
|
||||
// If not set on creation, the server will set it to the current time.
|
||||
CreateTime *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=create_time,json=createTime,proto3" json:"create_time,omitempty"`
|
||||
// Output only. The last update timestamp.
|
||||
// The last update timestamp.
|
||||
// If not set on creation, the server will set it to the current time.
|
||||
UpdateTime *timestamppb.Timestamp `protobuf:"bytes,5,opt,name=update_time,json=updateTime,proto3" json:"update_time,omitempty"`
|
||||
// The display timestamp of the memo.
|
||||
DisplayTime *timestamppb.Timestamp `protobuf:"bytes,6,opt,name=display_time,json=displayTime,proto3" json:"display_time,omitempty"`
|
||||
|
|
@ -1816,9 +1818,9 @@ const file_api_v1_memo_service_proto_rawDesc = "" +
|
|||
"\x05state\x18\x02 \x01(\x0e2\x13.memos.api.v1.StateB\x03\xe0A\x02R\x05state\x123\n" +
|
||||
"\acreator\x18\x03 \x01(\tB\x19\xe0A\x03\xfaA\x13\n" +
|
||||
"\x11memos.api.v1/UserR\acreator\x12@\n" +
|
||||
"\vcreate_time\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\n" +
|
||||
"\vcreate_time\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x01R\n" +
|
||||
"createTime\x12@\n" +
|
||||
"\vupdate_time\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\n" +
|
||||
"\vupdate_time\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x01R\n" +
|
||||
"updateTime\x12B\n" +
|
||||
"\fdisplay_time\x18\x06 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x01R\vdisplayTime\x12\x1d\n" +
|
||||
"\acontent\x18\a \x01(\tB\x03\xe0A\x02R\acontent\x12=\n" +
|
||||
|
|
|
|||
|
|
@ -29,13 +29,10 @@ const (
|
|||
type User_Role int32
|
||||
|
||||
const (
|
||||
// Unspecified role.
|
||||
User_ROLE_UNSPECIFIED User_Role = 0
|
||||
// Host role with full system access.
|
||||
User_HOST User_Role = 1
|
||||
// Admin role with administrative privileges.
|
||||
// Admin role with system access.
|
||||
User_ADMIN User_Role = 2
|
||||
// Regular user role.
|
||||
// User role with limited access.
|
||||
User_USER User_Role = 3
|
||||
)
|
||||
|
||||
|
|
@ -43,13 +40,11 @@ const (
|
|||
var (
|
||||
User_Role_name = map[int32]string{
|
||||
0: "ROLE_UNSPECIFIED",
|
||||
1: "HOST",
|
||||
2: "ADMIN",
|
||||
3: "USER",
|
||||
}
|
||||
User_Role_value = map[string]int32{
|
||||
"ROLE_UNSPECIFIED": 0,
|
||||
"HOST": 1,
|
||||
"ADMIN": 2,
|
||||
"USER": 3,
|
||||
}
|
||||
|
|
@ -2509,7 +2504,7 @@ var File_api_v1_user_service_proto protoreflect.FileDescriptor
|
|||
|
||||
const file_api_v1_user_service_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x19api/v1/user_service.proto\x12\fmemos.api.v1\x1a\x13api/v1/common.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a google/protobuf/field_mask.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xcb\x04\n" +
|
||||
"\x19api/v1/user_service.proto\x12\fmemos.api.v1\x1a\x13api/v1/common.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a google/protobuf/field_mask.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xc1\x04\n" +
|
||||
"\x04User\x12\x17\n" +
|
||||
"\x04name\x18\x01 \x01(\tB\x03\xe0A\bR\x04name\x120\n" +
|
||||
"\x04role\x18\x02 \x01(\x0e2\x17.memos.api.v1.User.RoleB\x03\xe0A\x02R\x04role\x12\x1f\n" +
|
||||
|
|
@ -2525,10 +2520,9 @@ const file_api_v1_user_service_proto_rawDesc = "" +
|
|||
" \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\n" +
|
||||
"createTime\x12@\n" +
|
||||
"\vupdate_time\x18\v \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\n" +
|
||||
"updateTime\";\n" +
|
||||
"updateTime\"1\n" +
|
||||
"\x04Role\x12\x14\n" +
|
||||
"\x10ROLE_UNSPECIFIED\x10\x00\x12\b\n" +
|
||||
"\x04HOST\x10\x01\x12\t\n" +
|
||||
"\x10ROLE_UNSPECIFIED\x10\x00\x12\t\n" +
|
||||
"\x05ADMIN\x10\x02\x12\b\n" +
|
||||
"\x04USER\x10\x03:7\xeaA4\n" +
|
||||
"\x11memos.api.v1/User\x12\fusers/{user}\x1a\x04name*\x05users2\x04user\"\x9d\x01\n" +
|
||||
|
|
|
|||
|
|
@ -2132,20 +2132,21 @@ components:
|
|||
InstanceProfile:
|
||||
type: object
|
||||
properties:
|
||||
owner:
|
||||
type: string
|
||||
description: |-
|
||||
The name of instance owner.
|
||||
Format: users/{user}
|
||||
version:
|
||||
type: string
|
||||
description: Version is the current version of instance.
|
||||
mode:
|
||||
type: string
|
||||
description: Mode is the instance mode (e.g. "prod", "dev" or "demo").
|
||||
demo:
|
||||
type: boolean
|
||||
description: Demo indicates if the instance is in demo mode.
|
||||
instanceUrl:
|
||||
type: string
|
||||
description: Instance URL is the URL of the instance.
|
||||
admin:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/User'
|
||||
description: |-
|
||||
The first administrator who set up this instance.
|
||||
When null, instance requires initial setup (creating the first admin account).
|
||||
description: Instance profile message containing basic instance information.
|
||||
InstanceSetting:
|
||||
type: object
|
||||
|
|
@ -2470,14 +2471,16 @@ components:
|
|||
The name of the creator.
|
||||
Format: users/{user}
|
||||
createTime:
|
||||
readOnly: true
|
||||
type: string
|
||||
description: Output only. The creation timestamp.
|
||||
description: |-
|
||||
The creation timestamp.
|
||||
If not set on creation, the server will set it to the current time.
|
||||
format: date-time
|
||||
updateTime:
|
||||
readOnly: true
|
||||
type: string
|
||||
description: Output only. The last update timestamp.
|
||||
description: |-
|
||||
The last update timestamp.
|
||||
If not set on creation, the server will set it to the current time.
|
||||
format: date-time
|
||||
displayTime:
|
||||
type: string
|
||||
|
|
@ -2857,7 +2860,6 @@ components:
|
|||
role:
|
||||
enum:
|
||||
- ROLE_UNSPECIFIED
|
||||
- HOST
|
||||
- ADMIN
|
||||
- USER
|
||||
type: string
|
||||
|
|
|
|||
|
|
@ -1,30 +1,56 @@
|
|||
FROM golang:1.25-alpine AS backend
|
||||
FROM --platform=$BUILDPLATFORM golang:1.25-alpine AS backend
|
||||
WORKDIR /backend-build
|
||||
|
||||
# Install build dependencies
|
||||
RUN apk add --no-cache git ca-certificates
|
||||
|
||||
# Copy go mod files and download dependencies (cached layer)
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
RUN --mount=type=cache,target=/go/pkg/mod \
|
||||
go mod download
|
||||
|
||||
# Copy source code (use .dockerignore to exclude unnecessary files)
|
||||
COPY . .
|
||||
|
||||
# Please build frontend first, so that the static files are available.
|
||||
# Refer to `pnpm release` in package.json for the build command.
|
||||
ARG TARGETOS TARGETARCH VERSION COMMIT
|
||||
RUN --mount=type=cache,target=/go/pkg/mod \
|
||||
--mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -ldflags="-s -w" -o memos ./cmd/memos
|
||||
CGO_ENABLED=0 GOOS=$TARGETOS GOARCH=$TARGETARCH \
|
||||
go build \
|
||||
-trimpath \
|
||||
-ldflags="-s -w -extldflags '-static'" \
|
||||
-tags netgo,osusergo \
|
||||
-o memos \
|
||||
./cmd/memos
|
||||
|
||||
FROM alpine:latest AS monolithic
|
||||
WORKDIR /usr/local/memos
|
||||
# Use minimal Alpine with security updates
|
||||
FROM alpine:3.21 AS monolithic
|
||||
|
||||
RUN apk add --no-cache tzdata
|
||||
ENV TZ="UTC"
|
||||
# Install runtime dependencies and create non-root user in single layer
|
||||
RUN apk add --no-cache tzdata ca-certificates su-exec && \
|
||||
addgroup -g 10001 -S nonroot && \
|
||||
adduser -u 10001 -S -G nonroot -h /var/opt/memos nonroot && \
|
||||
mkdir -p /var/opt/memos /usr/local/memos && \
|
||||
chown -R nonroot:nonroot /var/opt/memos
|
||||
|
||||
COPY --from=backend /backend-build/memos /usr/local/memos/
|
||||
COPY ./scripts/entrypoint.sh /usr/local/memos/
|
||||
# Copy binary and entrypoint to /usr/local/memos
|
||||
COPY --from=backend /backend-build/memos /usr/local/memos/memos
|
||||
COPY --from=backend --chmod=755 /backend-build/scripts/entrypoint.sh /usr/local/memos/entrypoint.sh
|
||||
|
||||
# Run as root to fix permissions, entrypoint will drop to nonroot
|
||||
USER root
|
||||
|
||||
# Set working directory to the writable volume
|
||||
WORKDIR /var/opt/memos
|
||||
|
||||
# Data directory
|
||||
VOLUME /var/opt/memos
|
||||
|
||||
ENV TZ="UTC" \
|
||||
MEMOS_PORT="5230"
|
||||
|
||||
EXPOSE 5230
|
||||
|
||||
# Directory to store the data, which can be referenced as the mounting point.
|
||||
RUN mkdir -p /var/opt/memos
|
||||
VOLUME /var/opt/memos
|
||||
|
||||
ENV MEMOS_MODE="prod"
|
||||
ENV MEMOS_PORT="5230"
|
||||
|
||||
ENTRYPOINT ["./entrypoint.sh", "./memos"]
|
||||
ENTRYPOINT ["/usr/local/memos/entrypoint.sh", "/usr/local/memos/memos"]
|
||||
|
|
|
|||
|
|
@ -29,4 +29,4 @@ go build -o "$OUTPUT" ./cmd/memos
|
|||
|
||||
echo "Build successful!"
|
||||
echo "To run the application, execute the following command:"
|
||||
echo "$OUTPUT --mode dev"
|
||||
echo "$OUTPUT"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
services:
|
||||
memos:
|
||||
image: neosmemo/memos:latest
|
||||
image: neosmemo/memos:stable
|
||||
container_name: memos
|
||||
volumes:
|
||||
- ~/.memos/:/var/opt/memos
|
||||
|
|
|
|||
|
|
@ -1,5 +1,19 @@
|
|||
#!/usr/bin/env sh
|
||||
|
||||
# Fix ownership of data directory for users upgrading from older versions
|
||||
# where files were created as root
|
||||
MEMOS_UID=${MEMOS_UID:-10001}
|
||||
MEMOS_GID=${MEMOS_GID:-10001}
|
||||
DATA_DIR="/var/opt/memos"
|
||||
|
||||
if [ "$(id -u)" = "0" ]; then
|
||||
# Running as root, fix permissions and drop to nonroot
|
||||
if [ -d "$DATA_DIR" ]; then
|
||||
chown -R "$MEMOS_UID:$MEMOS_GID" "$DATA_DIR" 2>/dev/null || true
|
||||
fi
|
||||
exec su-exec "$MEMOS_UID:$MEMOS_GID" "$0" "$@"
|
||||
fi
|
||||
|
||||
file_env() {
|
||||
var="$1"
|
||||
fileVar="${var}_FILE"
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ func TestParseAccessTokenV2(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("parses token with different roles", func(t *testing.T) {
|
||||
roles := []string{"USER", "ADMIN", "HOST"}
|
||||
roles := []string{"USER", "ADMIN"}
|
||||
for _, role := range roles {
|
||||
token, _, err := GenerateAccessTokenV2(1, "testuser", role, "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
|
|
|||
|
|
@ -29,8 +29,9 @@ var PublicMethods = map[string]struct{}{
|
|||
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": {},
|
||||
|
||||
// Memo Service - public memos (visibility filtering done in service layer)
|
||||
"/memos.api.v1.MemoService/GetMemo": {},
|
||||
"/memos.api.v1.MemoService/ListMemos": {},
|
||||
"/memos.api.v1.MemoService/GetMemo": {},
|
||||
"/memos.api.v1.MemoService/ListMemos": {},
|
||||
"/memos.api.v1.MemoService/ListMemoComments": {},
|
||||
}
|
||||
|
||||
// IsPublicMethod checks if a procedure path is public (no authentication required).
|
||||
|
|
|
|||
|
|
@ -0,0 +1,191 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/jpeg"
|
||||
"testing"
|
||||
|
||||
"github.com/disintegration/imaging"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestShouldStripExif(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mimeType string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "JPEG should strip EXIF",
|
||||
mimeType: "image/jpeg",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "JPG should strip EXIF",
|
||||
mimeType: "image/jpg",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "TIFF should strip EXIF",
|
||||
mimeType: "image/tiff",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "WebP should strip EXIF",
|
||||
mimeType: "image/webp",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "HEIC should strip EXIF",
|
||||
mimeType: "image/heic",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "HEIF should strip EXIF",
|
||||
mimeType: "image/heif",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PNG should not strip EXIF",
|
||||
mimeType: "image/png",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "GIF should not strip EXIF",
|
||||
mimeType: "image/gif",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "text file should not strip EXIF",
|
||||
mimeType: "text/plain",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "PDF should not strip EXIF",
|
||||
mimeType: "application/pdf",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := shouldStripExif(tt.mimeType)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripImageExif(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a simple test image
|
||||
img := image.NewRGBA(image.Rect(0, 0, 100, 100))
|
||||
// Fill with red color
|
||||
for y := 0; y < 100; y++ {
|
||||
for x := 0; x < 100; x++ {
|
||||
img.Set(x, y, color.RGBA{R: 255, G: 0, B: 0, A: 255})
|
||||
}
|
||||
}
|
||||
|
||||
// Encode as JPEG
|
||||
var buf bytes.Buffer
|
||||
err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 90})
|
||||
require.NoError(t, err)
|
||||
originalData := buf.Bytes()
|
||||
|
||||
t.Run("strip JPEG metadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
strippedData, err := stripImageExif(originalData, "image/jpeg")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, strippedData)
|
||||
|
||||
// Verify it's still a valid image
|
||||
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, decodedImg.Bounds().Dx())
|
||||
assert.Equal(t, 100, decodedImg.Bounds().Dy())
|
||||
})
|
||||
|
||||
t.Run("strip JPG metadata (alternate extension)", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
strippedData, err := stripImageExif(originalData, "image/jpg")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, strippedData)
|
||||
|
||||
// Verify it's still a valid image
|
||||
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, decodedImg)
|
||||
})
|
||||
|
||||
t.Run("strip PNG metadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Encode as PNG first
|
||||
var pngBuf bytes.Buffer
|
||||
err := imaging.Encode(&pngBuf, img, imaging.PNG)
|
||||
require.NoError(t, err)
|
||||
|
||||
strippedData, err := stripImageExif(pngBuf.Bytes(), "image/png")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, strippedData)
|
||||
|
||||
// Verify it's still a valid image
|
||||
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, decodedImg.Bounds().Dx())
|
||||
assert.Equal(t, 100, decodedImg.Bounds().Dy())
|
||||
})
|
||||
|
||||
t.Run("handle WebP format by converting to JPEG", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// WebP format will be converted to JPEG
|
||||
strippedData, err := stripImageExif(originalData, "image/webp")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, strippedData)
|
||||
|
||||
// Verify it's a valid image
|
||||
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, decodedImg)
|
||||
})
|
||||
|
||||
t.Run("handle HEIC format by converting to JPEG", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
strippedData, err := stripImageExif(originalData, "image/heic")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, strippedData)
|
||||
|
||||
// Verify it's a valid image
|
||||
decodedImg, err := imaging.Decode(bytes.NewReader(strippedData))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, decodedImg)
|
||||
})
|
||||
|
||||
t.Run("return error for invalid image data", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
invalidData := []byte("not an image")
|
||||
_, err := stripImageExif(invalidData, "image/jpeg")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to decode image")
|
||||
})
|
||||
|
||||
t.Run("return error for empty image data", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
emptyData := []byte{}
|
||||
_, err := stripImageExif(emptyData, "image/jpeg")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
|
|
@ -14,6 +15,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/disintegration/imaging"
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
|
@ -38,6 +40,10 @@ const (
|
|||
MebiByte = 1024 * 1024
|
||||
// ThumbnailCacheFolder is the folder name where the thumbnail images are stored.
|
||||
ThumbnailCacheFolder = ".thumbnail_cache"
|
||||
|
||||
// defaultJPEGQuality is the JPEG quality used when re-encoding images for EXIF stripping.
|
||||
// Quality 95 maintains visual quality while ensuring metadata is removed.
|
||||
defaultJPEGQuality = 95
|
||||
)
|
||||
|
||||
var SupportedThumbnailMimeTypes = []string{
|
||||
|
|
@ -45,6 +51,17 @@ var SupportedThumbnailMimeTypes = []string{
|
|||
"image/jpeg",
|
||||
}
|
||||
|
||||
// exifCapableImageTypes defines image formats that may contain EXIF metadata.
|
||||
// These formats will have their EXIF metadata stripped on upload for privacy.
|
||||
var exifCapableImageTypes = map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/jpg": true,
|
||||
"image/tiff": true,
|
||||
"image/webp": true,
|
||||
"image/heic": true,
|
||||
"image/heif": true,
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.CreateAttachmentRequest) (*v1pb.Attachment, error) {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -111,6 +128,21 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat
|
|||
create.Size = int64(size)
|
||||
create.Blob = request.Attachment.Content
|
||||
|
||||
// Strip EXIF metadata from images for privacy protection.
|
||||
// This removes sensitive information like GPS location, device details, etc.
|
||||
if shouldStripExif(create.Type) {
|
||||
if strippedBlob, err := stripImageExif(create.Blob, create.Type); err != nil {
|
||||
// Log warning but continue with original image to ensure uploads don't fail.
|
||||
slog.Warn("failed to strip EXIF metadata from image",
|
||||
slog.String("type", create.Type),
|
||||
slog.String("filename", create.Filename),
|
||||
slog.String("error", err.Error()))
|
||||
} else {
|
||||
create.Blob = strippedBlob
|
||||
create.Size = int64(len(strippedBlob))
|
||||
}
|
||||
}
|
||||
|
||||
if err := SaveAttachmentBlob(ctx, s.Profile, s.Store, create); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to save attachment blob: %v", err)
|
||||
}
|
||||
|
|
@ -214,6 +246,12 @@ func (s *APIV1Service) GetAttachment(ctx context.Context, request *v1pb.GetAttac
|
|||
if attachment == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "attachment not found")
|
||||
}
|
||||
|
||||
// Check access permission based on linked memo visibility.
|
||||
if err := s.checkAttachmentAccess(ctx, attachment); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return convertAttachmentFromStore(attachment), nil
|
||||
}
|
||||
|
||||
|
|
@ -225,10 +263,24 @@ func (s *APIV1Service) UpdateAttachment(ctx context.Context, request *v1pb.Updat
|
|||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
|
||||
}
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
|
||||
}
|
||||
if attachment == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "attachment not found")
|
||||
}
|
||||
// Only the creator or admin can update the attachment.
|
||||
if attachment.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
currentTs := time.Now().Unix()
|
||||
update := &store.UpdateAttachment{
|
||||
|
|
@ -516,3 +568,89 @@ func (s *APIV1Service) validateAttachmentFilter(ctx context.Context, filterStr s
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAttachmentAccess verifies the user has permission to access the attachment.
|
||||
// For unlinked attachments (no memo), only the creator can access.
|
||||
// For linked attachments, access follows the memo's visibility rules.
|
||||
func (s *APIV1Service) checkAttachmentAccess(ctx context.Context, attachment *store.Attachment) error {
|
||||
user, _ := s.fetchCurrentUser(ctx)
|
||||
|
||||
// For unlinked attachments, only the creator can access.
|
||||
if attachment.MemoID == nil {
|
||||
if user == nil {
|
||||
return status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if attachment.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// For linked attachments, check memo visibility.
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to get memo: %v", err)
|
||||
}
|
||||
if memo == nil {
|
||||
return status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
if memo.Visibility == store.Public {
|
||||
return nil
|
||||
}
|
||||
if user == nil {
|
||||
return status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldStripExif checks if the MIME type is an image format that may contain EXIF metadata.
|
||||
// Returns true for formats like JPEG, TIFF, WebP, HEIC, and HEIF which commonly contain
|
||||
// privacy-sensitive metadata such as GPS coordinates, camera settings, and device information.
|
||||
func shouldStripExif(mimeType string) bool {
|
||||
return exifCapableImageTypes[mimeType]
|
||||
}
|
||||
|
||||
// stripImageExif removes EXIF metadata from image files by decoding and re-encoding them.
|
||||
// This prevents exposure of sensitive metadata such as GPS location, camera details, and timestamps.
|
||||
//
|
||||
// The function preserves the correct image orientation by applying EXIF orientation tags
|
||||
// during decoding before stripping all metadata. Images are re-encoded with high quality
|
||||
// to minimize visual degradation.
|
||||
//
|
||||
// Supported formats:
|
||||
// - JPEG/JPG: Re-encoded as JPEG with quality 95
|
||||
// - PNG: Re-encoded as PNG (lossless)
|
||||
// - TIFF/WebP/HEIC/HEIF: Re-encoded as JPEG with quality 95
|
||||
//
|
||||
// Returns the cleaned image data without any EXIF metadata, or an error if processing fails.
|
||||
func stripImageExif(imageData []byte, mimeType string) ([]byte, error) {
|
||||
// Decode image with automatic EXIF orientation correction.
|
||||
// This ensures the image displays correctly after metadata removal.
|
||||
img, err := imaging.Decode(bytes.NewReader(imageData), imaging.AutoOrientation(true))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to decode image")
|
||||
}
|
||||
|
||||
// Re-encode the image without EXIF metadata.
|
||||
var buf bytes.Buffer
|
||||
var encodeErr error
|
||||
|
||||
if mimeType == "image/png" {
|
||||
// Preserve PNG format for lossless encoding
|
||||
encodeErr = imaging.Encode(&buf, img, imaging.PNG)
|
||||
} else {
|
||||
// For JPEG, TIFF, WebP, HEIC, HEIF - re-encode as JPEG.
|
||||
// This ensures EXIF is stripped and provides good compression.
|
||||
encodeErr = imaging.Encode(&buf, img, imaging.JPEG, imaging.JPEGQuality(defaultJPEGQuality))
|
||||
}
|
||||
|
||||
if encodeErr != nil {
|
||||
return nil, errors.Wrap(encodeErr, "failed to encode image")
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -64,5 +64,5 @@ func unmarshalPageToken(s string, pageToken *v1pb.PageToken) error {
|
|||
}
|
||||
|
||||
func isSuperUser(user *store.User) bool {
|
||||
return user.Role == store.RoleAdmin || user.Role == store.RoleHost
|
||||
return user.Role == store.RoleAdmin
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
|
@ -50,10 +51,30 @@ func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc
|
|||
|
||||
// Set metadata in context so services can use metadata.FromIncomingContext()
|
||||
ctx = metadata.NewIncomingContext(ctx, md)
|
||||
return next(ctx, req)
|
||||
|
||||
// Execute the request
|
||||
resp, err := next(ctx, req)
|
||||
|
||||
// Prevent browser caching of API responses to avoid stale data issues
|
||||
// See: https://github.com/usememos/memos/issues/5470
|
||||
if !isNilAnyResponse(resp) && resp.Header() != nil {
|
||||
resp.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
resp.Header().Set("Pragma", "no-cache")
|
||||
resp.Header().Set("Expires", "0")
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
func isNilAnyResponse(resp connect.AnyResponse) bool {
|
||||
if resp == nil {
|
||||
return true
|
||||
}
|
||||
val := reflect.ValueOf(resp)
|
||||
return val.Kind() == reflect.Ptr && val.IsNil()
|
||||
}
|
||||
|
||||
func (*MetadataInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
|
||||
return next
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,10 @@ func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.Role != store.RoleHost {
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
|
|
@ -84,7 +87,10 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.Role != store.RoleHost {
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
|
|
@ -125,7 +131,10 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.Role != store.RoleHost {
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
|
|
@ -219,7 +228,7 @@ func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProv
|
|||
}
|
||||
|
||||
func redactIdentityProviderResponse(identityProvider *v1pb.IdentityProvider, userRole store.Role) *v1pb.IdentityProvider {
|
||||
if userRole != store.RoleHost {
|
||||
if userRole != store.RoleAdmin {
|
||||
if identityProvider.Type == v1pb.IdentityProvider_OAUTH2 {
|
||||
identityProvider.Config.GetOauth2Config().ClientSecret = ""
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,17 +15,16 @@ import (
|
|||
|
||||
// GetInstanceProfile returns the instance profile.
|
||||
func (s *APIV1Service) GetInstanceProfile(ctx context.Context, _ *v1pb.GetInstanceProfileRequest) (*v1pb.InstanceProfile, error) {
|
||||
admin, err := s.GetInstanceAdmin(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance admin: %v", err)
|
||||
}
|
||||
|
||||
instanceProfile := &v1pb.InstanceProfile{
|
||||
Version: s.Profile.Version,
|
||||
Mode: s.Profile.Mode,
|
||||
Demo: s.Profile.Demo,
|
||||
InstanceUrl: s.Profile.InstanceURL,
|
||||
}
|
||||
owner, err := s.GetInstanceOwner(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance owner: %v", err)
|
||||
}
|
||||
if owner != nil {
|
||||
instanceProfile.Owner = owner.Name
|
||||
Admin: admin, // nil when not initialized
|
||||
}
|
||||
return instanceProfile, nil
|
||||
}
|
||||
|
|
@ -64,13 +63,16 @@ func (s *APIV1Service) GetInstanceSetting(ctx context.Context, request *v1pb.Get
|
|||
return nil, status.Errorf(codes.NotFound, "instance setting not found")
|
||||
}
|
||||
|
||||
// For storage setting, only host can get it.
|
||||
// For storage setting, only admin can get it.
|
||||
if instanceSetting.Key == storepb.InstanceSettingKey_STORAGE {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil || user.Role != store.RoleHost {
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if user.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
}
|
||||
|
|
@ -86,7 +88,7 @@ func (s *APIV1Service) UpdateInstanceSetting(ctx context.Context, request *v1pb.
|
|||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if user.Role != store.RoleHost {
|
||||
if user.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
|
|
@ -267,24 +269,17 @@ func convertInstanceMemoRelatedSettingToStore(setting *v1pb.InstanceSetting_Memo
|
|||
}
|
||||
}
|
||||
|
||||
var ownerCache *v1pb.User
|
||||
|
||||
func (s *APIV1Service) GetInstanceOwner(ctx context.Context) (*v1pb.User, error) {
|
||||
if ownerCache != nil {
|
||||
return ownerCache, nil
|
||||
}
|
||||
|
||||
hostUserType := store.RoleHost
|
||||
func (s *APIV1Service) GetInstanceAdmin(ctx context.Context) (*v1pb.User, error) {
|
||||
adminUserType := store.RoleAdmin
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Role: &hostUserType,
|
||||
Role: &adminUserType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to find owner")
|
||||
return nil, errors.Wrapf(err, "failed to find admin")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ownerCache = convertUserFromStore(user)
|
||||
return ownerCache, nil
|
||||
return convertUserFromStore(user), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -98,6 +98,24 @@ func (s *APIV1Service) ListMemoAttachments(ctx context.Context, request *v1pb.Li
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
// Check memo visibility.
|
||||
if memo.Visibility != store.Public {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -45,10 +45,35 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
|
|||
Content: request.Memo.Content,
|
||||
Visibility: convertVisibilityToStore(request.Memo.Visibility),
|
||||
}
|
||||
|
||||
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
|
||||
}
|
||||
|
||||
// Handle display_time first: if provided, use it to set the appropriate timestamp
|
||||
// based on the instance setting (similar to UpdateMemo logic)
|
||||
// Note: explicit create_time/update_time below will override this if provided
|
||||
if request.Memo.DisplayTime != nil && request.Memo.DisplayTime.IsValid() {
|
||||
displayTs := request.Memo.DisplayTime.AsTime().Unix()
|
||||
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
create.UpdatedTs = displayTs
|
||||
} else {
|
||||
create.CreatedTs = displayTs
|
||||
}
|
||||
}
|
||||
|
||||
// Set custom timestamps if provided in the request
|
||||
// These take precedence over display_time
|
||||
if request.Memo.CreateTime != nil && request.Memo.CreateTime.IsValid() {
|
||||
createdTs := request.Memo.CreateTime.AsTime().Unix()
|
||||
create.CreatedTs = createdTs
|
||||
}
|
||||
if request.Memo.UpdateTime != nil && request.Memo.UpdateTime.IsValid() {
|
||||
updatedTs := request.Memo.UpdateTime.AsTime().Unix()
|
||||
create.UpdatedTs = updatedTs
|
||||
}
|
||||
|
||||
if instanceMemoRelatedSetting.DisallowPublicVisibility && create.Visibility == store.Public {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
|
||||
}
|
||||
|
|
@ -281,7 +306,7 @@ func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest
|
|||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
|
|
@ -497,23 +522,7 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
|
|||
}
|
||||
}
|
||||
|
||||
if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo")
|
||||
}
|
||||
|
||||
// Delete memo relation
|
||||
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{MemoID: &memo.ID}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo relations")
|
||||
}
|
||||
|
||||
// Delete related attachments.
|
||||
for _, attachment := range attachments {
|
||||
if err := s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{ID: attachment.ID}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete attachment")
|
||||
}
|
||||
}
|
||||
|
||||
// Delete memo comments
|
||||
// Delete memo comments first (store.DeleteMemo handles their relations and attachments)
|
||||
commentType := store.MemoRelationComment
|
||||
relations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{RelatedMemoID: &memo.ID, Type: &commentType})
|
||||
if err != nil {
|
||||
|
|
@ -525,10 +534,9 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
|
|||
}
|
||||
}
|
||||
|
||||
// Delete memo references
|
||||
referenceType := store.MemoRelationReference
|
||||
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{RelatedMemoID: &memo.ID, Type: &referenceType}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo references")
|
||||
// Delete the memo (store.DeleteMemo handles relation and attachment cleanup)
|
||||
if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo")
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
|
|
@ -543,6 +551,21 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
if relatedMemo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
// Check memo visibility before allowing comment.
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if relatedMemo.Visibility == store.Private && relatedMemo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
// Create the memo comment first.
|
||||
memoComment, err := s.CreateMemo(ctx, &v1pb.CreateMemoRequest{
|
||||
|
|
|
|||
|
|
@ -15,6 +15,33 @@ import (
|
|||
)
|
||||
|
||||
func (s *APIV1Service) ListMemoReactions(ctx context.Context, request *v1pb.ListMemoReactionsRequest) (*v1pb.ListMemoReactionsResponse, error) {
|
||||
// Extract memo UID and check visibility.
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
// Check memo visibility.
|
||||
if memo.Visibility != store.Public {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &request.Name,
|
||||
})
|
||||
|
|
@ -40,6 +67,25 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups
|
|||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Extract memo UID and check visibility before allowing reaction.
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Reaction.ContentId)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo: %v", err)
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
// Check memo visibility.
|
||||
if memo.Visibility == store.Private && memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
reaction, err := s.Store.UpsertReaction(ctx, &store.Reaction{
|
||||
CreatorID: user.ID,
|
||||
ContentID: request.Reaction.ContentId,
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ func TestCreateIdentityProvider(t *testing.T) {
|
|||
|
||||
_, err := ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
require.Contains(t, err.Error(), "user not authenticated")
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -547,6 +547,6 @@ func TestIdentityProviderPermissions(t *testing.T) {
|
|||
|
||||
_, err := ts.Service.CreateIdentityProvider(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
require.Contains(t, err.Error(), "user not authenticated")
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,54 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestInstanceAdminRetrieval(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Instance becomes initialized after first admin user is created", func(t *testing.T) {
|
||||
// Create test service
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Verify instance is not initialized initially
|
||||
profile1, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, profile1.Admin, "Instance should not be initialized before first admin user")
|
||||
|
||||
// Create the first admin user
|
||||
user, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
|
||||
// Verify instance is now initialized
|
||||
profile2, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, profile2.Admin, "Instance should be initialized after first admin user is created")
|
||||
require.Equal(t, user.Username, profile2.Admin.Username)
|
||||
})
|
||||
|
||||
t.Run("Admin retrieval is cached by Store layer", func(t *testing.T) {
|
||||
// Create test service
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create admin user
|
||||
user, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Multiple calls should return consistent admin user (from cache)
|
||||
for i := 0; i < 5; i++ {
|
||||
profile, err := ts.Service.GetInstanceProfile(ctx, &v1pb.GetInstanceProfileRequest{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, profile.Admin)
|
||||
require.Equal(t, user.Username, profile.Admin.Username)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -2,7 +2,6 @@ package test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -28,14 +27,14 @@ func TestGetInstanceProfile(t *testing.T) {
|
|||
|
||||
// Verify the response contains expected data
|
||||
require.Equal(t, "test-1.0.0", resp.Version)
|
||||
require.Equal(t, "dev", resp.Mode)
|
||||
require.True(t, resp.Demo)
|
||||
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
|
||||
|
||||
// Owner should be empty since no users are created
|
||||
require.Empty(t, resp.Owner)
|
||||
// Instance should not be initialized since no admin users are created
|
||||
require.Nil(t, resp.Admin)
|
||||
})
|
||||
|
||||
t.Run("GetInstanceProfile with owner", func(t *testing.T) {
|
||||
t.Run("GetInstanceProfile with initialized instance", func(t *testing.T) {
|
||||
// Create test service for this specific test
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
|
@ -53,14 +52,14 @@ func TestGetInstanceProfile(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Verify the response contains expected data including owner
|
||||
// Verify the response contains expected data with initialized flag
|
||||
require.Equal(t, "test-1.0.0", resp.Version)
|
||||
require.Equal(t, "dev", resp.Mode)
|
||||
require.True(t, resp.Demo)
|
||||
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
|
||||
|
||||
// User name should be "users/{id}" format where id is the user's ID
|
||||
expectedOwnerName := fmt.Sprintf("users/%d", hostUser.ID)
|
||||
require.Equal(t, expectedOwnerName, resp.Owner)
|
||||
// Instance should be initialized since an admin user exists
|
||||
require.NotNil(t, resp.Admin)
|
||||
require.Equal(t, hostUser.Username, resp.Admin.Username)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -73,9 +72,8 @@ func TestGetInstanceProfile_Concurrency(t *testing.T) {
|
|||
defer ts.Cleanup()
|
||||
|
||||
// Create a host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
_, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
expectedOwnerName := fmt.Sprintf("users/%d", hostUser.ID)
|
||||
|
||||
// Make concurrent requests
|
||||
numGoroutines := 10
|
||||
|
|
@ -102,9 +100,9 @@ func TestGetInstanceProfile_Concurrency(t *testing.T) {
|
|||
case resp := <-results:
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "test-1.0.0", resp.Version)
|
||||
require.Equal(t, "dev", resp.Mode)
|
||||
require.True(t, resp.Demo)
|
||||
require.Equal(t, "http://localhost:8080", resp.InstanceUrl)
|
||||
require.Equal(t, expectedOwnerName, resp.Owner)
|
||||
require.NotNil(t, resp.Admin)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -5,8 +5,10 @@ import (
|
|||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
|
@ -250,3 +252,118 @@ func TestListMemos(t *testing.T) {
|
|||
require.NotNil(t, userTwoReaction)
|
||||
require.Equal(t, "👍", userTwoReaction.ReactionType)
|
||||
}
|
||||
|
||||
// TestCreateMemoWithCustomTimestamps tests that custom timestamps can be set when creating memos and comments.
|
||||
// This addresses issue #5483: https://github.com/usememos/memos/issues/5483
|
||||
func TestCreateMemoWithCustomTimestamps(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test user
|
||||
user, err := ts.CreateRegularUser(ctx, "test-user-timestamps")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Define custom timestamps (January 1, 2020)
|
||||
customCreateTime := time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
customUpdateTime := time.Date(2020, 1, 2, 12, 0, 0, 0, time.UTC)
|
||||
customDisplayTime := time.Date(2020, 1, 3, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Test 1: Create a memo with custom create_time
|
||||
memoWithCreateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has a custom creation time",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
CreateTime: timestamppb.New(customCreateTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithCreateTime)
|
||||
require.Equal(t, customCreateTime.Unix(), memoWithCreateTime.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp")
|
||||
|
||||
// Test 2: Create a memo with custom update_time
|
||||
memoWithUpdateTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has a custom update time",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
UpdateTime: timestamppb.New(customUpdateTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithUpdateTime)
|
||||
require.Equal(t, customUpdateTime.Unix(), memoWithUpdateTime.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp")
|
||||
|
||||
// Test 3: Create a memo with custom display_time
|
||||
// Note: display_time is computed from either created_ts or updated_ts based on instance setting
|
||||
// Since DisplayWithUpdateTime defaults to false, display_time maps to created_ts
|
||||
memoWithDisplayTime, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has a custom display time",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
DisplayTime: timestamppb.New(customDisplayTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithDisplayTime)
|
||||
// Since DisplayWithUpdateTime is false by default, display_time sets created_ts
|
||||
require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.DisplayTime.AsTime().Unix(), "display_time should match the custom timestamp")
|
||||
require.Equal(t, customDisplayTime.Unix(), memoWithDisplayTime.CreateTime.AsTime().Unix(), "create_time should also match since display_time maps to created_ts")
|
||||
|
||||
// Test 4: Create a memo with all custom timestamps
|
||||
// When both display_time and create_time are provided, create_time takes precedence
|
||||
memoWithAllTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has all custom timestamps",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
CreateTime: timestamppb.New(customCreateTime),
|
||||
UpdateTime: timestamppb.New(customUpdateTime),
|
||||
DisplayTime: timestamppb.New(customDisplayTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithAllTimestamps)
|
||||
require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.CreateTime.AsTime().Unix(), "create_time should match the custom timestamp")
|
||||
require.Equal(t, customUpdateTime.Unix(), memoWithAllTimestamps.UpdateTime.AsTime().Unix(), "update_time should match the custom timestamp")
|
||||
// display_time is computed from created_ts when DisplayWithUpdateTime is false
|
||||
require.Equal(t, customCreateTime.Unix(), memoWithAllTimestamps.DisplayTime.AsTime().Unix(), "display_time should be derived from create_time")
|
||||
|
||||
// Test 5: Create a comment (memo relation) with custom timestamps
|
||||
parentMemo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This is the parent memo",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parentMemo)
|
||||
|
||||
customCommentCreateTime := time.Date(2021, 6, 15, 10, 30, 0, 0, time.UTC)
|
||||
comment, err := ts.Service.CreateMemoComment(userCtx, &apiv1.CreateMemoCommentRequest{
|
||||
Name: parentMemo.Name,
|
||||
Comment: &apiv1.Memo{
|
||||
Content: "This is a comment with custom create time",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
CreateTime: timestamppb.New(customCommentCreateTime),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, comment)
|
||||
require.Equal(t, customCommentCreateTime.Unix(), comment.CreateTime.AsTime().Unix(), "comment create_time should match the custom timestamp")
|
||||
|
||||
// Test 6: Verify that memos without custom timestamps still get auto-generated ones
|
||||
memoWithoutTimestamps, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "This memo has auto-generated timestamps",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoWithoutTimestamps)
|
||||
require.NotNil(t, memoWithoutTimestamps.CreateTime, "create_time should be auto-generated")
|
||||
require.NotNil(t, memoWithoutTimestamps.UpdateTime, "update_time should be auto-generated")
|
||||
require.True(t, time.Now().Unix()-memoWithoutTimestamps.CreateTime.AsTime().Unix() < 5, "create_time should be recent (within 5 seconds)")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ func NewTestService(t *testing.T) *TestService {
|
|||
|
||||
// Create a test profile
|
||||
testProfile := &profile.Profile{
|
||||
Mode: "dev",
|
||||
Demo: true,
|
||||
Version: "test-1.0.0",
|
||||
InstanceURL: "http://localhost:8080",
|
||||
Driver: "sqlite",
|
||||
|
|
@ -56,17 +56,16 @@ func NewTestService(t *testing.T) *TestService {
|
|||
}
|
||||
}
|
||||
|
||||
// Cleanup clears caches and closes resources after test.
|
||||
// Cleanup closes resources after test.
|
||||
func (ts *TestService) Cleanup() {
|
||||
ts.Store.Close()
|
||||
// Note: Owner cache is package-level in parent package, cannot clear from test package
|
||||
}
|
||||
|
||||
// CreateHostUser creates a host user for testing.
|
||||
// CreateHostUser creates an admin user for testing.
|
||||
func (ts *TestService) CreateHostUser(ctx context.Context, username string) (*store.User, error) {
|
||||
return ts.Store.CreateUser(ctx, &store.User{
|
||||
Username: username,
|
||||
Role: store.RoleHost,
|
||||
Role: store.RoleAdmin,
|
||||
Email: username + "@example.com",
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ func (s *APIV1Service) ListUsers(ctx context.Context, request *v1pb.ListUsersReq
|
|||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
|
||||
if currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
|
|
@ -132,17 +132,17 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR
|
|||
// Determine the role to assign
|
||||
var roleToAssign store.Role
|
||||
if isFirstUser {
|
||||
// First-time setup: create the first user as HOST (no authentication required)
|
||||
roleToAssign = store.RoleHost
|
||||
} else if currentUser != nil && currentUser.Role == store.RoleHost {
|
||||
// Authenticated HOST user can create users with any role specified in request
|
||||
// First-time setup: create the first user as ADMIN (no authentication required)
|
||||
roleToAssign = store.RoleAdmin
|
||||
} else if currentUser != nil && currentUser.Role == store.RoleAdmin {
|
||||
// Authenticated ADMIN user can create users with any role specified in request
|
||||
if request.User.Role != v1pb.User_ROLE_UNSPECIFIED {
|
||||
roleToAssign = convertUserRoleToStore(request.User.Role)
|
||||
} else {
|
||||
roleToAssign = store.RoleUser
|
||||
}
|
||||
} else {
|
||||
// Unauthenticated or non-HOST users can only create normal users
|
||||
// Unauthenticated or non-ADMIN users can only create normal users
|
||||
roleToAssign = store.RoleUser
|
||||
}
|
||||
|
||||
|
|
@ -192,9 +192,12 @@ func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserR
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
// Check permission.
|
||||
// Only allow admin or self to update user.
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
|
|
@ -261,7 +264,7 @@ func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserR
|
|||
update.Description = &request.User.Description
|
||||
case "role":
|
||||
// Only allow admin to update role.
|
||||
if currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
|
||||
if currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
role := convertUserRoleToStore(request.User.Role)
|
||||
|
|
@ -298,7 +301,10 @@ func (s *APIV1Service) DeleteUser(ctx context.Context, request *v1pb.DeleteUserR
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
|
|
@ -539,7 +545,7 @@ func (s *APIV1Service) ListPersonalAccessTokens(ctx context.Context, request *v1
|
|||
claims := auth.GetUserClaims(ctx)
|
||||
if claims == nil || claims.UserID != userID {
|
||||
currentUser, _ := s.fetchCurrentUser(ctx)
|
||||
if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin) {
|
||||
if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleAdmin) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
}
|
||||
|
|
@ -686,7 +692,7 @@ func (s *APIV1Service) ListUserWebhooks(ctx context.Context, request *v1pb.ListU
|
|||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
|
|
@ -718,7 +724,7 @@ func (s *APIV1Service) CreateUserWebhook(ctx context.Context, request *v1pb.Crea
|
|||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
|
|
@ -758,7 +764,7 @@ func (s *APIV1Service) UpdateUserWebhook(ctx context.Context, request *v1pb.Upda
|
|||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
|
|
@ -830,7 +836,7 @@ func (s *APIV1Service) DeleteUserWebhook(ctx context.Context, request *v1pb.Dele
|
|||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
|
|
@ -925,8 +931,6 @@ func convertUserFromStore(user *store.User) *v1pb.User {
|
|||
|
||||
func convertUserRoleFromStore(role store.Role) v1pb.User_Role {
|
||||
switch role {
|
||||
case store.RoleHost:
|
||||
return v1pb.User_HOST
|
||||
case store.RoleAdmin:
|
||||
return v1pb.User_ADMIN
|
||||
case store.RoleUser:
|
||||
|
|
@ -938,8 +942,6 @@ func convertUserRoleFromStore(role store.Role) v1pb.User_Role {
|
|||
|
||||
func convertUserRoleToStore(role v1pb.User_Role) store.Role {
|
||||
switch role {
|
||||
case v1pb.User_HOST:
|
||||
return store.RoleHost
|
||||
case v1pb.User_ADMIN:
|
||||
return store.RoleAdmin
|
||||
default:
|
||||
|
|
@ -1240,6 +1242,9 @@ func (s *APIV1Service) ListUserNotifications(ctx context.Context, request *v1pb.
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.ID != userID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
|
@ -1287,6 +1292,9 @@ func (s *APIV1Service) UpdateUserNotification(ctx context.Context, request *v1pb
|
|||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
// Verify ownership before updating
|
||||
inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{
|
||||
ID: ¬ificationID,
|
||||
|
|
@ -1352,6 +1360,9 @@ func (s *APIV1Service) DeleteUserNotification(ctx context.Context, request *v1pb
|
|||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
// Verify ownership before deletion
|
||||
inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{
|
||||
ID: ¬ificationID,
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
|
|||
ctx := r.Context()
|
||||
|
||||
// Get the RPC method name from context (set by grpc-gateway after routing)
|
||||
rpcMethod, _ := runtime.RPCMethod(ctx)
|
||||
rpcMethod, ok := runtime.RPCMethod(ctx)
|
||||
|
||||
// Extract credentials from HTTP headers
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
|
|
@ -67,7 +67,8 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
|
|||
result := authenticator.Authenticate(ctx, authHeader)
|
||||
|
||||
// Enforce authentication for non-public methods
|
||||
if result == nil && !IsPublicMethod(rpcMethod) {
|
||||
// If rpcMethod cannot be determined, allow through, service layer will handle visibility checks
|
||||
if result == nil && ok && !IsPublicMethod(rpcMethod) {
|
||||
http.Error(w, `{"code": 16, "message": "authentication required"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
|
@ -125,7 +126,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
|
|||
gwGroup.Any("/file/*", handler)
|
||||
|
||||
// Connect handlers for browser clients (replaces grpc-web).
|
||||
logStacktraces := s.Profile.IsDev()
|
||||
logStacktraces := s.Profile.Demo
|
||||
connectInterceptors := connect.WithInterceptors(
|
||||
NewMetadataInterceptor(), // Convert HTTP headers to gRPC metadata first
|
||||
NewLoggingInterceptor(logStacktraces),
|
||||
|
|
|
|||
|
|
@ -26,13 +26,55 @@ import (
|
|||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// Constants for file serving configuration.
|
||||
const (
|
||||
// ThumbnailCacheFolder is the folder name where the thumbnail images are stored.
|
||||
// ThumbnailCacheFolder is the folder name where thumbnail images are stored.
|
||||
ThumbnailCacheFolder = ".thumbnail_cache"
|
||||
// thumbnailMaxSize is the maximum size in pixels for the largest dimension of the thumbnail image.
|
||||
|
||||
// thumbnailMaxSize is the maximum dimension (width or height) for thumbnails.
|
||||
thumbnailMaxSize = 600
|
||||
|
||||
// maxConcurrentThumbnails limits concurrent thumbnail generation to prevent memory exhaustion.
|
||||
maxConcurrentThumbnails = 3
|
||||
|
||||
// cacheMaxAge is the max-age value for Cache-Control headers (1 hour).
|
||||
cacheMaxAge = "public, max-age=3600"
|
||||
)
|
||||
|
||||
// xssUnsafeTypes contains MIME types that could execute scripts if served directly.
|
||||
// These are served as application/octet-stream to prevent XSS attacks.
|
||||
var xssUnsafeTypes = map[string]bool{
|
||||
"text/html": true,
|
||||
"text/javascript": true,
|
||||
"application/javascript": true,
|
||||
"application/x-javascript": true,
|
||||
"text/xml": true,
|
||||
"application/xml": true,
|
||||
"application/xhtml+xml": true,
|
||||
"image/svg+xml": true,
|
||||
}
|
||||
|
||||
// thumbnailSupportedTypes contains image MIME types that support thumbnail generation.
|
||||
var thumbnailSupportedTypes = map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/heic": true,
|
||||
"image/heif": true,
|
||||
"image/webp": true,
|
||||
}
|
||||
|
||||
// avatarAllowedTypes contains MIME types allowed for user avatars.
|
||||
var avatarAllowedTypes = map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/jpg": true,
|
||||
"image/gif": true,
|
||||
"image/webp": true,
|
||||
"image/heic": true,
|
||||
"image/heif": true,
|
||||
}
|
||||
|
||||
// SupportedThumbnailMimeTypes is the exported list of thumbnail-supported MIME types.
|
||||
var SupportedThumbnailMimeTypes = []string{
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
|
|
@ -41,15 +83,16 @@ var SupportedThumbnailMimeTypes = []string{
|
|||
"image/webp",
|
||||
}
|
||||
|
||||
// dataURIRegex parses data URI format: data:image/png;base64,iVBORw0KGgo...
|
||||
var dataURIRegex = regexp.MustCompile(`^data:(?P<type>[^;]+);base64,(?P<base64>.+)`)
|
||||
|
||||
// FileServerService handles HTTP file serving with proper range request support.
|
||||
// This service bypasses gRPC-Gateway to use native HTTP serving via http.ServeContent(),
|
||||
// which is required for Safari video/audio playback.
|
||||
type FileServerService struct {
|
||||
Profile *profile.Profile
|
||||
Store *store.Store
|
||||
authenticator *auth.Authenticator
|
||||
|
||||
// thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion
|
||||
// thumbnailSemaphore limits concurrent thumbnail generation.
|
||||
thumbnailSemaphore *semaphore.Weighted
|
||||
}
|
||||
|
||||
|
|
@ -59,29 +102,27 @@ func NewFileServerService(profile *profile.Profile, store *store.Store, secret s
|
|||
Profile: profile,
|
||||
Store: store,
|
||||
authenticator: auth.NewAuthenticator(store, secret),
|
||||
thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations
|
||||
thumbnailSemaphore: semaphore.NewWeighted(maxConcurrentThumbnails),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers HTTP file serving routes.
|
||||
func (s *FileServerService) RegisterRoutes(echoServer *echo.Echo) {
|
||||
fileGroup := echoServer.Group("/file")
|
||||
|
||||
// Serve attachment binary files
|
||||
fileGroup.GET("/attachments/:uid/:filename", s.serveAttachmentFile)
|
||||
|
||||
// Serve user avatar images
|
||||
fileGroup.GET("/users/:identifier/avatar", s.serveUserAvatar)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// HTTP Handlers
|
||||
// =============================================================================
|
||||
|
||||
// serveAttachmentFile serves attachment binary content using native HTTP.
|
||||
// This properly handles range requests required by Safari for video/audio playback.
|
||||
func (s *FileServerService) serveAttachmentFile(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
uid := c.Param("uid")
|
||||
thumbnail := c.QueryParam("thumbnail") == "true"
|
||||
wantThumbnail := c.QueryParam("thumbnail") == "true"
|
||||
|
||||
// Get attachment from database
|
||||
attachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{
|
||||
UID: &uid,
|
||||
GetBlob: true,
|
||||
|
|
@ -93,96 +134,25 @@ func (s *FileServerService) serveAttachmentFile(c echo.Context) error {
|
|||
return echo.NewHTTPError(http.StatusNotFound, "attachment not found")
|
||||
}
|
||||
|
||||
// Check permissions - verify memo visibility if attachment belongs to a memo
|
||||
if err := s.checkAttachmentPermission(ctx, c, attachment); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the binary content
|
||||
blob, err := s.getAttachmentBlob(attachment)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment blob").SetInternal(err)
|
||||
contentType := s.sanitizeContentType(attachment.Type)
|
||||
|
||||
// Stream video/audio to avoid loading entire file into memory.
|
||||
if isMediaType(attachment.Type) {
|
||||
return s.serveMediaStream(c, attachment, contentType)
|
||||
}
|
||||
|
||||
// Handle thumbnail requests for images
|
||||
if thumbnail && s.isImageType(attachment.Type) {
|
||||
thumbnailBlob, err := s.getOrGenerateThumbnail(ctx, attachment)
|
||||
if err != nil {
|
||||
// Log warning but fall back to original image
|
||||
c.Logger().Warnf("failed to get thumbnail: %v", err)
|
||||
} else {
|
||||
blob = thumbnailBlob
|
||||
}
|
||||
}
|
||||
|
||||
// Determine content type
|
||||
contentType := attachment.Type
|
||||
if strings.HasPrefix(contentType, "text/") {
|
||||
contentType += "; charset=utf-8"
|
||||
}
|
||||
// Prevent XSS attacks by serving potentially unsafe files as octet-stream
|
||||
unsafeTypes := []string{
|
||||
"text/html",
|
||||
"text/javascript",
|
||||
"application/javascript",
|
||||
"application/x-javascript",
|
||||
"text/xml",
|
||||
"application/xml",
|
||||
"application/xhtml+xml",
|
||||
"image/svg+xml",
|
||||
}
|
||||
for _, unsafeType := range unsafeTypes {
|
||||
if strings.EqualFold(contentType, unsafeType) {
|
||||
contentType = "application/octet-stream"
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Set common headers
|
||||
c.Response().Header().Set("Content-Type", contentType)
|
||||
c.Response().Header().Set("Cache-Control", "public, max-age=3600")
|
||||
// Prevent MIME-type sniffing which could lead to XSS
|
||||
c.Response().Header().Set("X-Content-Type-Options", "nosniff")
|
||||
// Defense-in-depth: prevent embedding in frames and restrict content loading
|
||||
c.Response().Header().Set("X-Frame-Options", "DENY")
|
||||
c.Response().Header().Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline';")
|
||||
// Support HDR/wide color gamut display for capable browsers
|
||||
if strings.HasPrefix(contentType, "image/") || strings.HasPrefix(contentType, "video/") {
|
||||
c.Response().Header().Set("Color-Gamut", "srgb, p3, rec2020")
|
||||
}
|
||||
|
||||
// Force download for non-media files to prevent XSS execution
|
||||
if !strings.HasPrefix(contentType, "image/") &&
|
||||
!strings.HasPrefix(contentType, "video/") &&
|
||||
!strings.HasPrefix(contentType, "audio/") &&
|
||||
contentType != "application/pdf" {
|
||||
c.Response().Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", attachment.Filename))
|
||||
}
|
||||
|
||||
// For video/audio: Use http.ServeContent for automatic range request support
|
||||
// This is critical for Safari which REQUIRES range request support
|
||||
if strings.HasPrefix(contentType, "video/") || strings.HasPrefix(contentType, "audio/") {
|
||||
// ServeContent automatically handles:
|
||||
// - Range request parsing
|
||||
// - HTTP 206 Partial Content responses
|
||||
// - Content-Range headers
|
||||
// - Accept-Ranges: bytes header
|
||||
modTime := time.Unix(attachment.UpdatedTs, 0)
|
||||
http.ServeContent(c.Response(), c.Request(), attachment.Filename, modTime, bytes.NewReader(blob))
|
||||
return nil
|
||||
}
|
||||
|
||||
// For other files: Simple blob response
|
||||
return c.Blob(http.StatusOK, contentType, blob)
|
||||
return s.serveStaticFile(c, attachment, contentType, wantThumbnail)
|
||||
}
|
||||
|
||||
// serveUserAvatar serves user avatar images.
|
||||
// Supports both user ID and username as identifier.
|
||||
func (s *FileServerService) serveUserAvatar(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
identifier := c.Param("identifier")
|
||||
|
||||
// Try to find user by ID or username
|
||||
user, err := s.getUserByIdentifier(ctx, identifier)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get user").SetInternal(err)
|
||||
|
|
@ -194,79 +164,327 @@ func (s *FileServerService) serveUserAvatar(c echo.Context) error {
|
|||
return echo.NewHTTPError(http.StatusNotFound, "avatar not found")
|
||||
}
|
||||
|
||||
// Extract image info from data URI
|
||||
imageType, base64Data, err := s.extractImageInfo(user.AvatarURL)
|
||||
imageType, imageData, err := s.parseDataURI(user.AvatarURL)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to extract image info").SetInternal(err)
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to parse avatar data").SetInternal(err)
|
||||
}
|
||||
|
||||
// Validate avatar MIME type to prevent XSS
|
||||
// Supports standard formats and HDR-capable formats
|
||||
allowedAvatarTypes := map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/jpg": true,
|
||||
"image/gif": true,
|
||||
"image/webp": true,
|
||||
"image/heic": true,
|
||||
"image/heif": true,
|
||||
}
|
||||
if !allowedAvatarTypes[imageType] {
|
||||
if !avatarAllowedTypes[imageType] {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid avatar image type")
|
||||
}
|
||||
|
||||
// Decode base64 data
|
||||
imageData, err := base64.StdEncoding.DecodeString(base64Data)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to decode image data").SetInternal(err)
|
||||
}
|
||||
|
||||
// Set cache headers for avatars
|
||||
c.Response().Header().Set("Content-Type", imageType)
|
||||
c.Response().Header().Set("Cache-Control", "public, max-age=3600")
|
||||
c.Response().Header().Set("X-Content-Type-Options", "nosniff")
|
||||
// Defense-in-depth: prevent embedding in frames
|
||||
c.Response().Header().Set("X-Frame-Options", "DENY")
|
||||
c.Response().Header().Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline';")
|
||||
setSecurityHeaders(c)
|
||||
c.Response().Header().Set(echo.HeaderContentType, imageType)
|
||||
c.Response().Header().Set(echo.HeaderCacheControl, cacheMaxAge)
|
||||
|
||||
return c.Blob(http.StatusOK, imageType, imageData)
|
||||
}
|
||||
|
||||
// getUserByIdentifier finds a user by either ID or username.
|
||||
func (s *FileServerService) getUserByIdentifier(ctx context.Context, identifier string) (*store.User, error) {
|
||||
// Try to parse as ID first
|
||||
if userID, err := util.ConvertStringToInt32(identifier); err == nil {
|
||||
return s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
|
||||
}
|
||||
// =============================================================================
|
||||
// File Serving Methods
|
||||
// =============================================================================
|
||||
|
||||
// Otherwise, treat as username
|
||||
return s.Store.GetUser(ctx, &store.FindUser{Username: &identifier})
|
||||
// serveMediaStream serves video/audio files using streaming to avoid memory exhaustion.
|
||||
func (s *FileServerService) serveMediaStream(c echo.Context, attachment *store.Attachment, contentType string) error {
|
||||
setSecurityHeaders(c)
|
||||
setMediaHeaders(c, contentType, attachment.Type)
|
||||
|
||||
switch attachment.StorageType {
|
||||
case storepb.AttachmentStorageType_LOCAL:
|
||||
filePath, err := s.resolveLocalPath(attachment.Reference)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to resolve file path").SetInternal(err)
|
||||
}
|
||||
http.ServeFile(c.Response(), c.Request(), filePath)
|
||||
return nil
|
||||
|
||||
case storepb.AttachmentStorageType_S3:
|
||||
presignURL, err := s.getS3PresignedURL(c.Request().Context(), attachment)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to generate presigned URL").SetInternal(err)
|
||||
}
|
||||
return c.Redirect(http.StatusTemporaryRedirect, presignURL)
|
||||
|
||||
default:
|
||||
// Database storage fallback.
|
||||
modTime := time.Unix(attachment.UpdatedTs, 0)
|
||||
http.ServeContent(c.Response(), c.Request(), attachment.Filename, modTime, bytes.NewReader(attachment.Blob))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// extractImageInfo extracts image type and base64 data from a data URI.
|
||||
// Data URI format: data:image/png;base64,iVBORw0KGgo...
|
||||
func (*FileServerService) extractImageInfo(dataURI string) (string, string, error) {
|
||||
dataURIRegex := regexp.MustCompile(`^data:(?P<type>.+);base64,(?P<base64>.+)`)
|
||||
matches := dataURIRegex.FindStringSubmatch(dataURI)
|
||||
if len(matches) != 3 {
|
||||
return "", "", errors.New("invalid data URI format")
|
||||
// serveStaticFile serves non-streaming files (images, documents, etc.).
|
||||
func (s *FileServerService) serveStaticFile(c echo.Context, attachment *store.Attachment, contentType string, wantThumbnail bool) error {
|
||||
blob, err := s.getAttachmentBlob(attachment)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment blob").SetInternal(err)
|
||||
}
|
||||
imageType := matches[1]
|
||||
base64Data := matches[2]
|
||||
return imageType, base64Data, nil
|
||||
|
||||
// Generate thumbnail for supported image types.
|
||||
if wantThumbnail && thumbnailSupportedTypes[attachment.Type] {
|
||||
if thumbnailBlob, err := s.getOrGenerateThumbnail(c.Request().Context(), attachment); err != nil {
|
||||
c.Logger().Warnf("failed to get thumbnail: %v", err)
|
||||
} else {
|
||||
blob = thumbnailBlob
|
||||
}
|
||||
}
|
||||
|
||||
setSecurityHeaders(c)
|
||||
setMediaHeaders(c, contentType, attachment.Type)
|
||||
|
||||
// Force download for non-media files to prevent XSS execution.
|
||||
if !strings.HasPrefix(contentType, "image/") && contentType != "application/pdf" {
|
||||
c.Response().Header().Set(echo.HeaderContentDisposition, fmt.Sprintf("attachment; filename=%q", attachment.Filename))
|
||||
}
|
||||
|
||||
return c.Blob(http.StatusOK, contentType, blob)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Storage Operations
|
||||
// =============================================================================
|
||||
|
||||
// getAttachmentBlob retrieves the binary content of an attachment from storage.
|
||||
func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]byte, error) {
|
||||
switch attachment.StorageType {
|
||||
case storepb.AttachmentStorageType_LOCAL:
|
||||
return s.readLocalFile(attachment.Reference)
|
||||
|
||||
case storepb.AttachmentStorageType_S3:
|
||||
return s.downloadFromS3(attachment)
|
||||
|
||||
default:
|
||||
return attachment.Blob, nil
|
||||
}
|
||||
}
|
||||
|
||||
// getAttachmentReader returns a reader for streaming attachment content.
|
||||
func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (io.ReadCloser, error) {
|
||||
switch attachment.StorageType {
|
||||
case storepb.AttachmentStorageType_LOCAL:
|
||||
filePath, err := s.resolveLocalPath(attachment.Reference)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, "file not found")
|
||||
}
|
||||
return nil, errors.Wrap(err, "failed to open file")
|
||||
}
|
||||
return file, nil
|
||||
|
||||
case storepb.AttachmentStorageType_S3:
|
||||
s3Client, s3Object, err := s.createS3Client(attachment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reader, err := s3Client.GetObjectStream(context.Background(), s3Object.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to stream from S3")
|
||||
}
|
||||
return reader, nil
|
||||
|
||||
default:
|
||||
return io.NopCloser(bytes.NewReader(attachment.Blob)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// resolveLocalPath converts a storage reference to an absolute file path.
|
||||
func (s *FileServerService) resolveLocalPath(reference string) (string, error) {
|
||||
filePath := filepath.FromSlash(reference)
|
||||
if !filepath.IsAbs(filePath) {
|
||||
filePath = filepath.Join(s.Profile.Data, filePath)
|
||||
}
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
// readLocalFile reads the entire contents of a local file.
|
||||
func (s *FileServerService) readLocalFile(reference string) ([]byte, error) {
|
||||
filePath, err := s.resolveLocalPath(reference)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, "file not found")
|
||||
}
|
||||
return nil, errors.Wrap(err, "failed to open file")
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
blob, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read file")
|
||||
}
|
||||
return blob, nil
|
||||
}
|
||||
|
||||
// createS3Client creates an S3 client from attachment payload.
|
||||
func (*FileServerService) createS3Client(attachment *store.Attachment) (*s3.Client, *storepb.AttachmentPayload_S3Object, error) {
|
||||
if attachment.Payload == nil {
|
||||
return nil, nil, errors.New("attachment payload is missing")
|
||||
}
|
||||
s3Object := attachment.Payload.GetS3Object()
|
||||
if s3Object == nil {
|
||||
return nil, nil, errors.New("S3 object payload is missing")
|
||||
}
|
||||
if s3Object.S3Config == nil {
|
||||
return nil, nil, errors.New("S3 config is missing")
|
||||
}
|
||||
if s3Object.Key == "" {
|
||||
return nil, nil, errors.New("S3 object key is missing")
|
||||
}
|
||||
|
||||
client, err := s3.NewClient(context.Background(), s3Object.S3Config)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to create S3 client")
|
||||
}
|
||||
return client, s3Object, nil
|
||||
}
|
||||
|
||||
// downloadFromS3 downloads the entire object from S3.
|
||||
func (s *FileServerService) downloadFromS3(attachment *store.Attachment) ([]byte, error) {
|
||||
client, s3Object, err := s.createS3Client(attachment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blob, err := client.GetObject(context.Background(), s3Object.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to download from S3")
|
||||
}
|
||||
return blob, nil
|
||||
}
|
||||
|
||||
// getS3PresignedURL generates a presigned URL for direct S3 access.
|
||||
func (s *FileServerService) getS3PresignedURL(ctx context.Context, attachment *store.Attachment) (string, error) {
|
||||
client, s3Object, err := s.createS3Client(attachment)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
url, err := client.PresignGetObject(ctx, s3Object.Key)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to presign URL")
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Thumbnail Generation
|
||||
// =============================================================================
|
||||
|
||||
// getOrGenerateThumbnail returns the thumbnail image of the attachment.
|
||||
// Uses semaphore to limit concurrent thumbnail generation and prevent memory exhaustion.
|
||||
func (s *FileServerService) getOrGenerateThumbnail(ctx context.Context, attachment *store.Attachment) ([]byte, error) {
|
||||
thumbnailPath, err := s.getThumbnailPath(attachment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fast path: return cached thumbnail if exists.
|
||||
if blob, err := s.readCachedThumbnail(thumbnailPath); err == nil {
|
||||
return blob, nil
|
||||
}
|
||||
|
||||
// Acquire semaphore to limit concurrent generation.
|
||||
if err := s.thumbnailSemaphore.Acquire(ctx, 1); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to acquire semaphore")
|
||||
}
|
||||
defer s.thumbnailSemaphore.Release(1)
|
||||
|
||||
// Double-check after acquiring semaphore (another goroutine may have generated it).
|
||||
if blob, err := s.readCachedThumbnail(thumbnailPath); err == nil {
|
||||
return blob, nil
|
||||
}
|
||||
|
||||
return s.generateThumbnail(attachment, thumbnailPath)
|
||||
}
|
||||
|
||||
// getThumbnailPath returns the file path for a cached thumbnail.
|
||||
func (s *FileServerService) getThumbnailPath(attachment *store.Attachment) (string, error) {
|
||||
cacheFolder := filepath.Join(s.Profile.Data, ThumbnailCacheFolder)
|
||||
if err := os.MkdirAll(cacheFolder, os.ModePerm); err != nil {
|
||||
return "", errors.Wrap(err, "failed to create thumbnail cache folder")
|
||||
}
|
||||
filename := fmt.Sprintf("%d%s", attachment.ID, filepath.Ext(attachment.Filename))
|
||||
return filepath.Join(cacheFolder, filename), nil
|
||||
}
|
||||
|
||||
// readCachedThumbnail reads a thumbnail from the cache directory.
|
||||
func (*FileServerService) readCachedThumbnail(path string) ([]byte, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
return io.ReadAll(file)
|
||||
}
|
||||
|
||||
// generateThumbnail creates a new thumbnail and saves it to disk.
|
||||
func (s *FileServerService) generateThumbnail(attachment *store.Attachment, thumbnailPath string) ([]byte, error) {
|
||||
reader, err := s.getAttachmentReader(attachment)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get attachment reader")
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
img, err := imaging.Decode(reader, imaging.AutoOrientation(true))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to decode image")
|
||||
}
|
||||
|
||||
width, height := img.Bounds().Dx(), img.Bounds().Dy()
|
||||
thumbnailWidth, thumbnailHeight := calculateThumbnailDimensions(width, height)
|
||||
|
||||
thumbnailImage := imaging.Resize(img, thumbnailWidth, thumbnailHeight, imaging.Lanczos)
|
||||
|
||||
if err := imaging.Save(thumbnailImage, thumbnailPath); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to save thumbnail")
|
||||
}
|
||||
|
||||
return s.readCachedThumbnail(thumbnailPath)
|
||||
}
|
||||
|
||||
// calculateThumbnailDimensions calculates the target dimensions for a thumbnail.
|
||||
// The largest dimension is constrained to thumbnailMaxSize while maintaining aspect ratio.
|
||||
// Small images are not enlarged.
|
||||
func calculateThumbnailDimensions(width, height int) (int, int) {
|
||||
if max(width, height) <= thumbnailMaxSize {
|
||||
return width, height
|
||||
}
|
||||
if width >= height {
|
||||
return thumbnailMaxSize, 0 // Landscape: constrain width.
|
||||
}
|
||||
return 0, thumbnailMaxSize // Portrait: constrain height.
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Authentication & Authorization
|
||||
// =============================================================================
|
||||
|
||||
// checkAttachmentPermission verifies the user has permission to access the attachment.
|
||||
func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c echo.Context, attachment *store.Attachment) error {
|
||||
// If attachment is not linked to a memo, allow access
|
||||
// For unlinked attachments, only the creator can access.
|
||||
if attachment.MemoID == nil {
|
||||
user, err := s.getCurrentUser(ctx, c)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get current user").SetInternal(err)
|
||||
}
|
||||
if user == nil {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "unauthorized access")
|
||||
}
|
||||
if user.ID != attachment.CreatorID && user.Role != store.RoleAdmin {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "forbidden access")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check memo visibility
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: attachment.MemoID,
|
||||
})
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to find memo").SetInternal(err)
|
||||
}
|
||||
|
|
@ -274,12 +492,10 @@ func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c ech
|
|||
return echo.NewHTTPError(http.StatusNotFound, "memo not found")
|
||||
}
|
||||
|
||||
// Public memos are accessible to everyone
|
||||
if memo.Visibility == store.Public {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For non-public memos, check authentication
|
||||
user, err := s.getCurrentUser(ctx, c)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get current user").SetInternal(err)
|
||||
|
|
@ -288,278 +504,132 @@ func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c ech
|
|||
return echo.NewHTTPError(http.StatusUnauthorized, "unauthorized access")
|
||||
}
|
||||
|
||||
// Private memos can only be accessed by the creator
|
||||
if memo.Visibility == store.Private && user.ID != attachment.CreatorID {
|
||||
if memo.Visibility == store.Private && user.ID != memo.CreatorID && user.Role != store.RoleAdmin {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "forbidden access")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getCurrentUser retrieves the current authenticated user from the Echo context.
|
||||
// getCurrentUser retrieves the current authenticated user from the request.
|
||||
// Authentication priority: Bearer token (Access Token V2 or PAT) > Refresh token cookie.
|
||||
// Uses the shared Authenticator for consistent authentication logic.
|
||||
func (s *FileServerService) getCurrentUser(ctx context.Context, c echo.Context) (*store.User, error) {
|
||||
// Try Bearer token authentication first
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader != "" {
|
||||
token := auth.ExtractBearerToken(authHeader)
|
||||
if token != "" {
|
||||
// Try Access Token V2 (stateless)
|
||||
if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) {
|
||||
claims, err := s.authenticator.AuthenticateByAccessTokenV2(token)
|
||||
if err == nil && claims != nil {
|
||||
// Get user from claims
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &claims.UserID})
|
||||
if err == nil && user != nil {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try PAT
|
||||
if strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) {
|
||||
user, _, err := s.authenticator.AuthenticateByPAT(ctx, token)
|
||||
if err == nil && user != nil {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
// Try Bearer token authentication.
|
||||
if authHeader := c.Request().Header.Get(echo.HeaderAuthorization); authHeader != "" {
|
||||
if user, err := s.authenticateByBearerToken(ctx, authHeader); err == nil && user != nil {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: Try refresh token cookie authentication
|
||||
// This allows protected attachments to load even when access token has expired,
|
||||
// as long as the user has a valid refresh token cookie.
|
||||
cookieHeader := c.Request().Header.Get("Cookie")
|
||||
if cookieHeader != "" {
|
||||
refreshToken := auth.ExtractRefreshTokenFromCookie(cookieHeader)
|
||||
if refreshToken != "" {
|
||||
user, _, err := s.authenticator.AuthenticateByRefreshToken(ctx, refreshToken)
|
||||
if err == nil && user != nil {
|
||||
return user, nil
|
||||
}
|
||||
// Fallback: Try refresh token cookie.
|
||||
if cookieHeader := c.Request().Header.Get("Cookie"); cookieHeader != "" {
|
||||
if user, err := s.authenticateByRefreshToken(ctx, cookieHeader); err == nil && user != nil {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
|
||||
// No valid authentication found
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// isImageType checks if the mime type is an image that supports thumbnails.
|
||||
// Supports standard formats (PNG, JPEG) and HDR-capable formats (HEIC, HEIF, WebP).
|
||||
func (*FileServerService) isImageType(mimeType string) bool {
|
||||
supportedTypes := map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/heic": true,
|
||||
"image/heif": true,
|
||||
"image/webp": true,
|
||||
// authenticateByBearerToken authenticates using Authorization header.
|
||||
func (s *FileServerService) authenticateByBearerToken(ctx context.Context, authHeader string) (*store.User, error) {
|
||||
token := auth.ExtractBearerToken(authHeader)
|
||||
if token == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return supportedTypes[mimeType]
|
||||
|
||||
// Try Access Token V2 (stateless JWT).
|
||||
if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) {
|
||||
claims, err := s.authenticator.AuthenticateByAccessTokenV2(token)
|
||||
if err == nil && claims != nil {
|
||||
return s.Store.GetUser(ctx, &store.FindUser{ID: &claims.UserID})
|
||||
}
|
||||
}
|
||||
|
||||
// Try Personal Access Token (stateful).
|
||||
if strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) {
|
||||
user, _, err := s.authenticator.AuthenticateByPAT(ctx, token)
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// getAttachmentReader returns a reader for the attachment content.
|
||||
func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (io.ReadCloser, error) {
|
||||
// For local storage, read the file from the local disk.
|
||||
if attachment.StorageType == storepb.AttachmentStorageType_LOCAL {
|
||||
attachmentPath := filepath.FromSlash(attachment.Reference)
|
||||
if !filepath.IsAbs(attachmentPath) {
|
||||
attachmentPath = filepath.Join(s.Profile.Data, attachmentPath)
|
||||
}
|
||||
|
||||
file, err := os.Open(attachmentPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, "file not found")
|
||||
}
|
||||
return nil, errors.Wrap(err, "failed to open the file")
|
||||
}
|
||||
return file, nil
|
||||
// authenticateByRefreshToken authenticates using refresh token cookie.
|
||||
func (s *FileServerService) authenticateByRefreshToken(ctx context.Context, cookieHeader string) (*store.User, error) {
|
||||
refreshToken := auth.ExtractRefreshTokenFromCookie(cookieHeader)
|
||||
if refreshToken == "" {
|
||||
return nil, nil
|
||||
}
|
||||
// For S3 storage, download the file from S3.
|
||||
if attachment.StorageType == storepb.AttachmentStorageType_S3 {
|
||||
if attachment.Payload == nil {
|
||||
return nil, errors.New("attachment payload is missing")
|
||||
}
|
||||
s3Object := attachment.Payload.GetS3Object()
|
||||
if s3Object == nil {
|
||||
return nil, errors.New("S3 object payload is missing")
|
||||
}
|
||||
if s3Object.S3Config == nil {
|
||||
return nil, errors.New("S3 config is missing")
|
||||
}
|
||||
if s3Object.Key == "" {
|
||||
return nil, errors.New("S3 object key is missing")
|
||||
}
|
||||
|
||||
s3Client, err := s3.NewClient(context.Background(), s3Object.S3Config)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create S3 client")
|
||||
}
|
||||
|
||||
reader, err := s3Client.GetObjectStream(context.Background(), s3Object.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get object from S3")
|
||||
}
|
||||
return reader, nil
|
||||
}
|
||||
// For database storage, return the blob from the database.
|
||||
return io.NopCloser(bytes.NewReader(attachment.Blob)), nil
|
||||
user, _, err := s.authenticator.AuthenticateByRefreshToken(ctx, refreshToken)
|
||||
return user, err
|
||||
}
|
||||
|
||||
// getAttachmentBlob retrieves the binary content of an attachment from storage.
|
||||
func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]byte, error) {
|
||||
// For local storage, read the file from the local disk.
|
||||
if attachment.StorageType == storepb.AttachmentStorageType_LOCAL {
|
||||
attachmentPath := filepath.FromSlash(attachment.Reference)
|
||||
if !filepath.IsAbs(attachmentPath) {
|
||||
attachmentPath = filepath.Join(s.Profile.Data, attachmentPath)
|
||||
}
|
||||
|
||||
file, err := os.Open(attachmentPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, "file not found")
|
||||
}
|
||||
return nil, errors.Wrap(err, "failed to open the file")
|
||||
}
|
||||
defer file.Close()
|
||||
blob, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read the file")
|
||||
}
|
||||
return blob, nil
|
||||
// getUserByIdentifier finds a user by either ID or username.
|
||||
func (s *FileServerService) getUserByIdentifier(ctx context.Context, identifier string) (*store.User, error) {
|
||||
if userID, err := util.ConvertStringToInt32(identifier); err == nil {
|
||||
return s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
|
||||
}
|
||||
// For S3 storage, download the file from S3.
|
||||
if attachment.StorageType == storepb.AttachmentStorageType_S3 {
|
||||
if attachment.Payload == nil {
|
||||
return nil, errors.New("attachment payload is missing")
|
||||
}
|
||||
s3Object := attachment.Payload.GetS3Object()
|
||||
if s3Object == nil {
|
||||
return nil, errors.New("S3 object payload is missing")
|
||||
}
|
||||
if s3Object.S3Config == nil {
|
||||
return nil, errors.New("S3 config is missing")
|
||||
}
|
||||
if s3Object.Key == "" {
|
||||
return nil, errors.New("S3 object key is missing")
|
||||
}
|
||||
|
||||
s3Client, err := s3.NewClient(context.Background(), s3Object.S3Config)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create S3 client")
|
||||
}
|
||||
|
||||
blob, err := s3Client.GetObject(context.Background(), s3Object.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get object from S3")
|
||||
}
|
||||
return blob, nil
|
||||
}
|
||||
// For database storage, return the blob from the database.
|
||||
return attachment.Blob, nil
|
||||
return s.Store.GetUser(ctx, &store.FindUser{Username: &identifier})
|
||||
}
|
||||
|
||||
// getOrGenerateThumbnail returns the thumbnail image of the attachment.
|
||||
// Uses semaphore to limit concurrent thumbnail generation and prevent memory exhaustion.
|
||||
func (s *FileServerService) getOrGenerateThumbnail(ctx context.Context, attachment *store.Attachment) ([]byte, error) {
|
||||
thumbnailCacheFolder := filepath.Join(s.Profile.Data, ThumbnailCacheFolder)
|
||||
if err := os.MkdirAll(thumbnailCacheFolder, os.ModePerm); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create thumbnail cache folder")
|
||||
}
|
||||
filePath := filepath.Join(thumbnailCacheFolder, fmt.Sprintf("%d%s", attachment.ID, filepath.Ext(attachment.Filename)))
|
||||
// =============================================================================
|
||||
// Helper Functions
|
||||
// =============================================================================
|
||||
|
||||
// Check if thumbnail already exists
|
||||
if _, err := os.Stat(filePath); err == nil {
|
||||
// Thumbnail exists, read and return it
|
||||
thumbnailFile, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to open thumbnail file")
|
||||
}
|
||||
defer thumbnailFile.Close()
|
||||
blob, err := io.ReadAll(thumbnailFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read thumbnail file")
|
||||
}
|
||||
return blob, nil
|
||||
} else if !os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, "failed to check thumbnail image stat")
|
||||
// sanitizeContentType converts potentially dangerous MIME types to safe alternatives.
|
||||
func (*FileServerService) sanitizeContentType(mimeType string) string {
|
||||
contentType := mimeType
|
||||
if strings.HasPrefix(contentType, "text/") {
|
||||
contentType += "; charset=utf-8"
|
||||
}
|
||||
|
||||
// Thumbnail doesn't exist, acquire semaphore to limit concurrent generation
|
||||
if err := s.thumbnailSemaphore.Acquire(ctx, 1); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to acquire thumbnail generation semaphore")
|
||||
// Normalize for case-insensitive lookup.
|
||||
if xssUnsafeTypes[strings.ToLower(mimeType)] {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
return contentType
|
||||
}
|
||||
|
||||
// parseDataURI extracts MIME type and decoded data from a data URI.
|
||||
func (*FileServerService) parseDataURI(dataURI string) (string, []byte, error) {
|
||||
matches := dataURIRegex.FindStringSubmatch(dataURI)
|
||||
if len(matches) != 3 {
|
||||
return "", nil, errors.New("invalid data URI format")
|
||||
}
|
||||
|
||||
imageType := matches[1]
|
||||
imageData, err := base64.StdEncoding.DecodeString(matches[2])
|
||||
if err != nil {
|
||||
return "", nil, errors.Wrap(err, "failed to decode base64 data")
|
||||
}
|
||||
|
||||
return imageType, imageData, nil
|
||||
}
|
||||
|
||||
// isMediaType checks if the MIME type is video or audio.
|
||||
func isMediaType(mimeType string) bool {
|
||||
return strings.HasPrefix(mimeType, "video/") || strings.HasPrefix(mimeType, "audio/")
|
||||
}
|
||||
|
||||
// setSecurityHeaders sets common security headers for all responses.
|
||||
func setSecurityHeaders(c echo.Context) {
|
||||
h := c.Response().Header()
|
||||
h.Set("X-Content-Type-Options", "nosniff")
|
||||
h.Set("X-Frame-Options", "DENY")
|
||||
h.Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline';")
|
||||
}
|
||||
|
||||
// setMediaHeaders sets headers for media file responses.
|
||||
func setMediaHeaders(c echo.Context, contentType, originalType string) {
|
||||
h := c.Response().Header()
|
||||
h.Set(echo.HeaderContentType, contentType)
|
||||
h.Set(echo.HeaderCacheControl, cacheMaxAge)
|
||||
|
||||
// Support HDR/wide color gamut for images and videos.
|
||||
if strings.HasPrefix(originalType, "image/") || strings.HasPrefix(originalType, "video/") {
|
||||
h.Set("Color-Gamut", "srgb, p3, rec2020")
|
||||
}
|
||||
defer s.thumbnailSemaphore.Release(1)
|
||||
|
||||
// Double-check if thumbnail was created while waiting for semaphore
|
||||
if _, err := os.Stat(filePath); err == nil {
|
||||
thumbnailFile, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to open thumbnail file")
|
||||
}
|
||||
defer thumbnailFile.Close()
|
||||
blob, err := io.ReadAll(thumbnailFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read thumbnail file")
|
||||
}
|
||||
return blob, nil
|
||||
}
|
||||
|
||||
// Generate the thumbnail
|
||||
reader, err := s.getAttachmentReader(attachment)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get attachment reader")
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Decode image - this is memory intensive
|
||||
img, err := imaging.Decode(reader, imaging.AutoOrientation(true))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to decode thumbnail image")
|
||||
}
|
||||
|
||||
// The largest dimension is set to thumbnailMaxSize and the smaller dimension is scaled proportionally.
|
||||
// Small images are not enlarged.
|
||||
width := img.Bounds().Dx()
|
||||
height := img.Bounds().Dy()
|
||||
var thumbnailWidth, thumbnailHeight int
|
||||
|
||||
// Only resize if the image is larger than thumbnailMaxSize
|
||||
if max(width, height) > thumbnailMaxSize {
|
||||
if width >= height {
|
||||
// Landscape or square - constrain width, maintain aspect ratio for height
|
||||
thumbnailWidth = thumbnailMaxSize
|
||||
thumbnailHeight = 0
|
||||
} else {
|
||||
// Portrait - constrain height, maintain aspect ratio for width
|
||||
thumbnailWidth = 0
|
||||
thumbnailHeight = thumbnailMaxSize
|
||||
}
|
||||
} else {
|
||||
// Keep original dimensions for small images
|
||||
thumbnailWidth = width
|
||||
thumbnailHeight = height
|
||||
}
|
||||
|
||||
// Resize the image to the calculated dimensions.
|
||||
thumbnailImage := imaging.Resize(img, thumbnailWidth, thumbnailHeight, imaging.Lanczos)
|
||||
|
||||
// Save thumbnail to disk
|
||||
if err := imaging.Save(thumbnailImage, filePath); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to save thumbnail file")
|
||||
}
|
||||
|
||||
// Read the saved thumbnail and return it
|
||||
thumbnailFile, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to open thumbnail file")
|
||||
}
|
||||
defer thumbnailFile.Close()
|
||||
thumbnailBlob, err := io.ReadAll(thumbnailFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read thumbnail file")
|
||||
}
|
||||
return thumbnailBlob, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
|
|||
}
|
||||
|
||||
secret := "usememos"
|
||||
if profile.Mode == "prod" {
|
||||
if !profile.Demo {
|
||||
secret = instanceBasicSetting.SecretKey
|
||||
}
|
||||
s.Secret = secret
|
||||
|
|
|
|||
|
|
@ -161,20 +161,20 @@ func (c *Cache) Delete(_ context.Context, key string) {
|
|||
|
||||
// Clear removes all values from the cache.
|
||||
func (c *Cache) Clear(_ context.Context) {
|
||||
if c.config.OnEviction != nil {
|
||||
c.data.Range(func(key, value any) bool {
|
||||
count := 0
|
||||
c.data.Range(func(key, value any) bool {
|
||||
if c.config.OnEviction != nil {
|
||||
itm, ok := value.(item)
|
||||
if !ok {
|
||||
return true
|
||||
if ok {
|
||||
if keyStr, ok := key.(string); ok {
|
||||
c.config.OnEviction(keyStr, itm.value)
|
||||
}
|
||||
}
|
||||
if keyStr, ok := key.(string); ok {
|
||||
c.config.OnEviction(keyStr, itm.value)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
c.data = sync.Map{}
|
||||
}
|
||||
c.data.Delete(key)
|
||||
count++
|
||||
return true
|
||||
})
|
||||
(&c.itemCount).Store(0)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s
|
|||
}
|
||||
args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}
|
||||
|
||||
stmt := "INSERT INTO `resource` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
stmt := "INSERT INTO `attachment` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -50,38 +50,38 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
|
|||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`resource`.`id` = ?"), append(args, *v)
|
||||
where, args = append(where, "`attachment`.`id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "`resource`.`uid` = ?"), append(args, *v)
|
||||
where, args = append(where, "`attachment`.`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "`resource`.`creator_id` = ?"), append(args, *v)
|
||||
where, args = append(where, "`attachment`.`creator_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Filename; v != nil {
|
||||
where, args = append(where, "`resource`.`filename` = ?"), append(args, *v)
|
||||
where, args = append(where, "`attachment`.`filename` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.FilenameSearch; v != nil {
|
||||
where, args = append(where, "`resource`.`filename` LIKE ?"), append(args, "%"+*v+"%")
|
||||
where, args = append(where, "`attachment`.`filename` LIKE ?"), append(args, "%"+*v+"%")
|
||||
}
|
||||
if v := find.MemoID; v != nil {
|
||||
where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v)
|
||||
where, args = append(where, "`attachment`.`memo_id` = ?"), append(args, *v)
|
||||
}
|
||||
if len(find.MemoIDList) > 0 {
|
||||
placeholders := make([]string, 0, len(find.MemoIDList))
|
||||
for range find.MemoIDList {
|
||||
placeholders = append(placeholders, "?")
|
||||
}
|
||||
where = append(where, "`resource`.`memo_id` IN ("+strings.Join(placeholders, ",")+")")
|
||||
where = append(where, "`attachment`.`memo_id` IN ("+strings.Join(placeholders, ",")+")")
|
||||
for _, id := range find.MemoIDList {
|
||||
args = append(args, id)
|
||||
}
|
||||
}
|
||||
if find.HasRelatedMemo {
|
||||
where = append(where, "`resource`.`memo_id` IS NOT NULL")
|
||||
where = append(where, "`attachment`.`memo_id` IS NOT NULL")
|
||||
}
|
||||
if find.StorageType != nil {
|
||||
where, args = append(where, "`resource`.`storage_type` = ?"), append(args, find.StorageType.String())
|
||||
where, args = append(where, "`attachment`.`storage_type` = ?"), append(args, find.StorageType.String())
|
||||
}
|
||||
|
||||
if len(find.Filters) > 0 {
|
||||
|
|
@ -95,26 +95,26 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
|
|||
}
|
||||
|
||||
fields := []string{
|
||||
"`resource`.`id` AS `id`",
|
||||
"`resource`.`uid` AS `uid`",
|
||||
"`resource`.`filename` AS `filename`",
|
||||
"`resource`.`type` AS `type`",
|
||||
"`resource`.`size` AS `size`",
|
||||
"`resource`.`creator_id` AS `creator_id`",
|
||||
"UNIX_TIMESTAMP(`resource`.`created_ts`) AS `created_ts`",
|
||||
"UNIX_TIMESTAMP(`resource`.`updated_ts`) AS `updated_ts`",
|
||||
"`resource`.`memo_id` AS `memo_id`",
|
||||
"`resource`.`storage_type` AS `storage_type`",
|
||||
"`resource`.`reference` AS `reference`",
|
||||
"`resource`.`payload` AS `payload`",
|
||||
"`attachment`.`id` AS `id`",
|
||||
"`attachment`.`uid` AS `uid`",
|
||||
"`attachment`.`filename` AS `filename`",
|
||||
"`attachment`.`type` AS `type`",
|
||||
"`attachment`.`size` AS `size`",
|
||||
"`attachment`.`creator_id` AS `creator_id`",
|
||||
"UNIX_TIMESTAMP(`attachment`.`created_ts`) AS `created_ts`",
|
||||
"UNIX_TIMESTAMP(`attachment`.`updated_ts`) AS `updated_ts`",
|
||||
"`attachment`.`memo_id` AS `memo_id`",
|
||||
"`attachment`.`storage_type` AS `storage_type`",
|
||||
"`attachment`.`reference` AS `reference`",
|
||||
"`attachment`.`payload` AS `payload`",
|
||||
"CASE WHEN `memo`.`uid` IS NOT NULL THEN `memo`.`uid` ELSE NULL END AS `memo_uid`",
|
||||
}
|
||||
if find.GetBlob {
|
||||
fields = append(fields, "`resource`.`blob` AS `blob`")
|
||||
fields = append(fields, "`attachment`.`blob` AS `blob`")
|
||||
}
|
||||
|
||||
query := "SELECT " + strings.Join(fields, ", ") + " FROM `resource`" + " " +
|
||||
"LEFT JOIN `memo` ON `resource`.`memo_id` = `memo`.`id`" + " " +
|
||||
query := "SELECT " + strings.Join(fields, ", ") + " FROM `attachment`" + " " +
|
||||
"LEFT JOIN `memo` ON `attachment`.`memo_id` = `memo`.`id`" + " " +
|
||||
"WHERE " + strings.Join(where, " AND ") + " " +
|
||||
"ORDER BY `updated_ts` DESC"
|
||||
if find.Limit != nil {
|
||||
|
|
@ -216,7 +216,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen
|
|||
}
|
||||
|
||||
args = append(args, update.ID)
|
||||
stmt := "UPDATE `resource` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
stmt := "UPDATE `attachment` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -228,7 +228,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen
|
|||
}
|
||||
|
||||
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
|
||||
stmt := "DELETE FROM `resource` WHERE `id` = ?"
|
||||
stmt := "DELETE FROM `attachment` WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -63,7 +63,11 @@ func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.I
|
|||
if find.MessageType != nil {
|
||||
// Filter by message type using JSON extraction
|
||||
// Note: The type field in JSON is stored as string representation of the enum name
|
||||
where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String())
|
||||
if *find.MessageType == storepb.InboxMessage_TYPE_UNSPECIFIED {
|
||||
where, args = append(where, "(JSON_EXTRACT(`message`, '$.type') IS NULL OR JSON_EXTRACT(`message`, '$.type') = ?)"), append(args, find.MessageType.String())
|
||||
} else {
|
||||
where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String())
|
||||
}
|
||||
}
|
||||
|
||||
query := "SELECT `id`, UNIX_TIMESTAMP(`created_ts`), `sender_id`, `receiver_id`, `status`, `message` FROM `inbox` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"
|
||||
|
|
|
|||
|
|
@ -26,6 +26,18 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e
|
|||
}
|
||||
args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}
|
||||
|
||||
// Add custom timestamps if provided
|
||||
if create.CreatedTs != 0 {
|
||||
fields = append(fields, "`created_ts`")
|
||||
placeholder = append(placeholder, "?")
|
||||
args = append(args, create.CreatedTs)
|
||||
}
|
||||
if create.UpdatedTs != 0 {
|
||||
fields = append(fields, "`updated_ts`")
|
||||
placeholder = append(placeholder, "?")
|
||||
args = append(args, create.UpdatedTs)
|
||||
}
|
||||
|
||||
stmt := "INSERT INTO `memo` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -1,162 +0,0 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `tag in ["tag1", "tag2"]`,
|
||||
want: "((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))",
|
||||
args: []any{`"tag1"`, `%"tag1/%`, `"tag2"`, `%"tag2/%`},
|
||||
},
|
||||
{
|
||||
filter: `!(tag in ["tag1", "tag2"])`,
|
||||
want: "NOT (((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?)))",
|
||||
args: []any{`"tag1"`, `%"tag1/%`, `"tag2"`, `%"tag2/%`},
|
||||
},
|
||||
{
|
||||
filter: `content.contains("memos")`,
|
||||
want: "`memo`.`content` LIKE ?",
|
||||
args: []any{"%memos%"},
|
||||
},
|
||||
{
|
||||
filter: `visibility in ["PUBLIC"]`,
|
||||
want: "`memo`.`visibility` IN (?)",
|
||||
args: []any{"PUBLIC"},
|
||||
},
|
||||
{
|
||||
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
|
||||
want: "`memo`.`visibility` IN (?,?)",
|
||||
args: []any{"PUBLIC", "PRIVATE"},
|
||||
},
|
||||
{
|
||||
filter: `tag in ['tag1'] || content.contains('hello')`,
|
||||
want: "((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR `memo`.`content` LIKE ?)",
|
||||
args: []any{`"tag1"`, `%"tag1/%`, "%hello%"},
|
||||
},
|
||||
{
|
||||
filter: `1`,
|
||||
want: "",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `pinned`,
|
||||
want: "`memo`.`pinned` IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list == true`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list != false`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list == false`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `!has_task_list`,
|
||||
want: "NOT (JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON))",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list && pinned`,
|
||||
want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON) AND `memo`.`pinned` IS TRUE)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list && content.contains("todo")`,
|
||||
want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON) AND `memo`.`content` LIKE ?)",
|
||||
args: []any{"%todo%"},
|
||||
},
|
||||
{
|
||||
filter: `created_ts > now() - 60 * 60 * 24`,
|
||||
want: "UNIX_TIMESTAMP(`memo`.`created_ts`) > ?",
|
||||
args: []any{time.Now().Unix() - 60*60*24},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) == 0`,
|
||||
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
|
||||
args: []any{int64(0)},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) > 0`,
|
||||
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?",
|
||||
args: []any{int64(0)},
|
||||
},
|
||||
{
|
||||
filter: `"work" in tags`,
|
||||
want: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
|
||||
args: []any{`"work"`},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) == 2`,
|
||||
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
|
||||
args: []any{int64(2)},
|
||||
},
|
||||
{
|
||||
filter: `has_link == true`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_code == false`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('false' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_incomplete_tasks != false`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') != CAST('false' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_link`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_code`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('true' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_incomplete_tasks`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') = CAST('true' AS JSON)",
|
||||
args: []any{},
|
||||
},
|
||||
}
|
||||
|
||||
engine, err := filter.DefaultEngine()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tt := range tests {
|
||||
stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{
|
||||
Dialect: filter.DialectMySQL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, stmt.SQL)
|
||||
require.Equal(t, tt.args, stmt.Args)
|
||||
}
|
||||
}
|
||||
|
|
@ -10,7 +10,7 @@ import (
|
|||
)
|
||||
|
||||
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
|
||||
stmt := "INSERT INTO `memo_relation` (`memo_id`, `related_memo_id`, `type`) VALUES (?, ?, ?)"
|
||||
stmt := "INSERT INTO `memo_relation` (`memo_id`, `related_memo_id`, `type`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `type` = `type`"
|
||||
_, err := d.db.ExecContext(
|
||||
ctx,
|
||||
stmt,
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s
|
|||
}
|
||||
args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}
|
||||
|
||||
stmt := "INSERT INTO resource (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts"
|
||||
stmt := "INSERT INTO attachment (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -41,22 +41,22 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
|
|||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "resource.id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
where, args = append(where, "attachment.id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "resource.uid = "+placeholder(len(args)+1)), append(args, *v)
|
||||
where, args = append(where, "attachment.uid = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "resource.creator_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
where, args = append(where, "attachment.creator_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Filename; v != nil {
|
||||
where, args = append(where, "resource.filename = "+placeholder(len(args)+1)), append(args, *v)
|
||||
where, args = append(where, "attachment.filename = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.FilenameSearch; v != nil {
|
||||
where, args = append(where, "resource.filename LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", *v))
|
||||
where, args = append(where, "attachment.filename LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", *v))
|
||||
}
|
||||
if v := find.MemoID; v != nil {
|
||||
where, args = append(where, "resource.memo_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
where, args = append(where, "attachment.memo_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if len(find.MemoIDList) > 0 {
|
||||
holders := make([]string, 0, len(find.MemoIDList))
|
||||
|
|
@ -64,13 +64,13 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
|
|||
holders = append(holders, placeholder(len(args)+1))
|
||||
args = append(args, id)
|
||||
}
|
||||
where = append(where, "resource.memo_id IN ("+strings.Join(holders, ", ")+")")
|
||||
where = append(where, "attachment.memo_id IN ("+strings.Join(holders, ", ")+")")
|
||||
}
|
||||
if find.HasRelatedMemo {
|
||||
where = append(where, "resource.memo_id IS NOT NULL")
|
||||
where = append(where, "attachment.memo_id IS NOT NULL")
|
||||
}
|
||||
if v := find.StorageType; v != nil {
|
||||
where, args = append(where, "resource.storage_type = "+placeholder(len(args)+1)), append(args, v.String())
|
||||
where, args = append(where, "attachment.storage_type = "+placeholder(len(args)+1)), append(args, v.String())
|
||||
}
|
||||
|
||||
if len(find.Filters) > 0 {
|
||||
|
|
@ -84,31 +84,31 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
|
|||
}
|
||||
|
||||
fields := []string{
|
||||
"resource.id AS id",
|
||||
"resource.uid AS uid",
|
||||
"resource.filename AS filename",
|
||||
"resource.type AS type",
|
||||
"resource.size AS size",
|
||||
"resource.creator_id AS creator_id",
|
||||
"resource.created_ts AS created_ts",
|
||||
"resource.updated_ts AS updated_ts",
|
||||
"resource.memo_id AS memo_id",
|
||||
"resource.storage_type AS storage_type",
|
||||
"resource.reference AS reference",
|
||||
"resource.payload AS payload",
|
||||
"attachment.id AS id",
|
||||
"attachment.uid AS uid",
|
||||
"attachment.filename AS filename",
|
||||
"attachment.type AS type",
|
||||
"attachment.size AS size",
|
||||
"attachment.creator_id AS creator_id",
|
||||
"attachment.created_ts AS created_ts",
|
||||
"attachment.updated_ts AS updated_ts",
|
||||
"attachment.memo_id AS memo_id",
|
||||
"attachment.storage_type AS storage_type",
|
||||
"attachment.reference AS reference",
|
||||
"attachment.payload AS payload",
|
||||
"CASE WHEN memo.uid IS NOT NULL THEN memo.uid ELSE NULL END AS memo_uid",
|
||||
}
|
||||
if find.GetBlob {
|
||||
fields = append(fields, "resource.blob AS blob")
|
||||
fields = append(fields, "attachment.blob AS blob")
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
%s
|
||||
FROM resource
|
||||
LEFT JOIN memo ON resource.memo_id = memo.id
|
||||
FROM attachment
|
||||
LEFT JOIN memo ON attachment.memo_id = memo.id
|
||||
WHERE %s
|
||||
ORDER BY resource.updated_ts DESC
|
||||
ORDER BY attachment.updated_ts DESC
|
||||
`, strings.Join(fields, ", "), strings.Join(where, " AND "))
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
|
|
@ -196,7 +196,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen
|
|||
set, args = append(set, "payload = "+placeholder(len(args)+1)), append(args, string(bytes))
|
||||
}
|
||||
|
||||
stmt := `UPDATE resource SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1)
|
||||
stmt := `UPDATE attachment SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1)
|
||||
args = append(args, update.ID)
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
|
|
@ -209,7 +209,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen
|
|||
}
|
||||
|
||||
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
|
||||
stmt := `DELETE FROM resource WHERE id = $1`
|
||||
stmt := `DELETE FROM attachment WHERE id = $1`
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -53,7 +53,12 @@ func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.I
|
|||
if find.MessageType != nil {
|
||||
// Filter by message type using PostgreSQL JSON extraction
|
||||
// Note: The type field in JSON is stored as string representation of the enum name
|
||||
where, args = append(where, "message->>'type' = "+placeholder(len(args)+1)), append(args, find.MessageType.String())
|
||||
// Cast to JSONB since the column is TEXT
|
||||
if *find.MessageType == storepb.InboxMessage_TYPE_UNSPECIFIED {
|
||||
where, args = append(where, "(message::JSONB->>'type' IS NULL OR message::JSONB->>'type' = "+placeholder(len(args)+1)+")"), append(args, find.MessageType.String())
|
||||
} else {
|
||||
where, args = append(where, "message::JSONB->>'type' = "+placeholder(len(args)+1)), append(args, find.MessageType.String())
|
||||
}
|
||||
}
|
||||
|
||||
query := "SELECT id, created_ts, sender_id, receiver_id, status, message FROM inbox WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"
|
||||
|
|
|
|||
|
|
@ -25,6 +25,16 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e
|
|||
}
|
||||
args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}
|
||||
|
||||
// Add custom timestamps if provided
|
||||
if create.CreatedTs != 0 {
|
||||
fields = append(fields, "created_ts")
|
||||
args = append(args, create.CreatedTs)
|
||||
}
|
||||
if create.UpdatedTs != 0 {
|
||||
fields = append(fields, "updated_ts")
|
||||
args = append(args, create.UpdatedTs)
|
||||
}
|
||||
|
||||
stmt := "INSERT INTO memo (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts, row_status"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
|
|
|
|||
|
|
@ -1,160 +0,0 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `tag in ["tag1", "tag2"]`,
|
||||
want: "((memo.payload->'tags' @> jsonb_build_array($1::json) OR (memo.payload->'tags')::text LIKE $2) OR (memo.payload->'tags' @> jsonb_build_array($3::json) OR (memo.payload->'tags')::text LIKE $4))",
|
||||
args: []any{`"tag1"`, `%"tag1/%`, `"tag2"`, `%"tag2/%`},
|
||||
},
|
||||
{
|
||||
filter: `!(tag in ["tag1", "tag2"])`,
|
||||
want: "NOT (((memo.payload->'tags' @> jsonb_build_array($1::json) OR (memo.payload->'tags')::text LIKE $2) OR (memo.payload->'tags' @> jsonb_build_array($3::json) OR (memo.payload->'tags')::text LIKE $4)))",
|
||||
args: []any{`"tag1"`, `%"tag1/%`, `"tag2"`, `%"tag2/%`},
|
||||
},
|
||||
{
|
||||
filter: `content.contains("memos")`,
|
||||
want: "memo.content ILIKE $1",
|
||||
args: []any{"%memos%"},
|
||||
},
|
||||
{
|
||||
filter: `visibility in ["PUBLIC"]`,
|
||||
want: "memo.visibility IN ($1)",
|
||||
args: []any{"PUBLIC"},
|
||||
},
|
||||
{
|
||||
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
|
||||
want: "memo.visibility IN ($1,$2)",
|
||||
args: []any{"PUBLIC", "PRIVATE"},
|
||||
},
|
||||
{
|
||||
filter: `tag in ['tag1'] || content.contains('hello')`,
|
||||
want: "((memo.payload->'tags' @> jsonb_build_array($1::json) OR (memo.payload->'tags')::text LIKE $2) OR memo.content ILIKE $3)",
|
||||
args: []any{`"tag1"`, `%"tag1/%`, "%hello%"},
|
||||
},
|
||||
{
|
||||
filter: `1`,
|
||||
want: "",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `pinned`,
|
||||
want: "memo.pinned IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list`,
|
||||
want: "(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list == true`,
|
||||
want: "(memo.payload->'property'->>'hasTaskList')::boolean = $1",
|
||||
args: []any{true},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list != false`,
|
||||
want: "(memo.payload->'property'->>'hasTaskList')::boolean != $1",
|
||||
args: []any{false},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list == false`,
|
||||
want: "(memo.payload->'property'->>'hasTaskList')::boolean = $1",
|
||||
args: []any{false},
|
||||
},
|
||||
{
|
||||
filter: `!has_task_list`,
|
||||
want: "NOT ((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list && pinned`,
|
||||
want: "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.pinned IS TRUE)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list && content.contains("todo")`,
|
||||
want: "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.content ILIKE $1)",
|
||||
args: []any{"%todo%"},
|
||||
},
|
||||
{
|
||||
filter: `created_ts > now() - 60 * 60 * 24`,
|
||||
want: "memo.created_ts > $1",
|
||||
args: []any{time.Now().Unix() - 60*60*24},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) == 0`,
|
||||
want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1",
|
||||
args: []any{int64(0)},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) > 0`,
|
||||
want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) > $1",
|
||||
args: []any{int64(0)},
|
||||
},
|
||||
{
|
||||
filter: `"work" in tags`,
|
||||
want: "memo.payload->'tags' @> jsonb_build_array($1::json)",
|
||||
args: []any{`"work"`},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) == 2`,
|
||||
want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1",
|
||||
args: []any{int64(2)},
|
||||
},
|
||||
{
|
||||
filter: `has_link == true`,
|
||||
want: "(memo.payload->'property'->>'hasLink')::boolean = $1",
|
||||
args: []any{true},
|
||||
},
|
||||
{
|
||||
filter: `has_code == false`,
|
||||
want: "(memo.payload->'property'->>'hasCode')::boolean = $1",
|
||||
args: []any{false},
|
||||
},
|
||||
{
|
||||
filter: `has_incomplete_tasks != false`,
|
||||
want: "(memo.payload->'property'->>'hasIncompleteTasks')::boolean != $1",
|
||||
args: []any{false},
|
||||
},
|
||||
{
|
||||
filter: `has_link`,
|
||||
want: "(memo.payload->'property'->>'hasLink')::boolean IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_code`,
|
||||
want: "(memo.payload->'property'->>'hasCode')::boolean IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_incomplete_tasks`,
|
||||
want: "(memo.payload->'property'->>'hasIncompleteTasks')::boolean IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
}
|
||||
|
||||
engine, err := filter.DefaultEngine()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tt := range tests {
|
||||
stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{Dialect: filter.DialectPostgres})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, stmt.SQL)
|
||||
require.Equal(t, tt.args, stmt.Args)
|
||||
}
|
||||
}
|
||||
|
|
@ -17,6 +17,7 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation)
|
|||
type
|
||||
)
|
||||
VALUES (` + placeholders(3) + `)
|
||||
ON CONFLICT (memo_id, related_memo_id, type) DO UPDATE SET type = EXCLUDED.type
|
||||
RETURNING memo_id, related_memo_id, type
|
||||
`
|
||||
memoRelation := &store.MemoRelation{}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s
|
|||
}
|
||||
args := []any{create.UID, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID, create.MemoID, storageType, create.Reference, payloadString}
|
||||
|
||||
stmt := "INSERT INTO `resource` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`, `updated_ts`"
|
||||
stmt := "INSERT INTO `attachment` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`, `updated_ts`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -43,38 +43,38 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
|
|||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`resource`.`id` = ?"), append(args, *v)
|
||||
where, args = append(where, "`attachment`.`id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "`resource`.`uid` = ?"), append(args, *v)
|
||||
where, args = append(where, "`attachment`.`uid` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "`resource`.`creator_id` = ?"), append(args, *v)
|
||||
where, args = append(where, "`attachment`.`creator_id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Filename; v != nil {
|
||||
where, args = append(where, "`resource`.`filename` = ?"), append(args, *v)
|
||||
where, args = append(where, "`attachment`.`filename` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.FilenameSearch; v != nil {
|
||||
where, args = append(where, "`resource`.`filename` LIKE ?"), append(args, fmt.Sprintf("%%%s%%", *v))
|
||||
where, args = append(where, "`attachment`.`filename` LIKE ?"), append(args, fmt.Sprintf("%%%s%%", *v))
|
||||
}
|
||||
if v := find.MemoID; v != nil {
|
||||
where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v)
|
||||
where, args = append(where, "`attachment`.`memo_id` = ?"), append(args, *v)
|
||||
}
|
||||
if len(find.MemoIDList) > 0 {
|
||||
placeholders := make([]string, 0, len(find.MemoIDList))
|
||||
for range find.MemoIDList {
|
||||
placeholders = append(placeholders, "?")
|
||||
}
|
||||
where = append(where, "`resource`.`memo_id` IN ("+strings.Join(placeholders, ",")+")")
|
||||
where = append(where, "`attachment`.`memo_id` IN ("+strings.Join(placeholders, ",")+")")
|
||||
for _, id := range find.MemoIDList {
|
||||
args = append(args, id)
|
||||
}
|
||||
}
|
||||
if find.HasRelatedMemo {
|
||||
where = append(where, "`resource`.`memo_id` IS NOT NULL")
|
||||
where = append(where, "`attachment`.`memo_id` IS NOT NULL")
|
||||
}
|
||||
if find.StorageType != nil {
|
||||
where, args = append(where, "`resource`.`storage_type` = ?"), append(args, find.StorageType.String())
|
||||
where, args = append(where, "`attachment`.`storage_type` = ?"), append(args, find.StorageType.String())
|
||||
}
|
||||
|
||||
if len(find.Filters) > 0 {
|
||||
|
|
@ -88,28 +88,28 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
|
|||
}
|
||||
|
||||
fields := []string{
|
||||
"`resource`.`id` AS `id`",
|
||||
"`resource`.`uid` AS `uid`",
|
||||
"`resource`.`filename` AS `filename`",
|
||||
"`resource`.`type` AS `type`",
|
||||
"`resource`.`size` AS `size`",
|
||||
"`resource`.`creator_id` AS `creator_id`",
|
||||
"`resource`.`created_ts` AS `created_ts`",
|
||||
"`resource`.`updated_ts` AS `updated_ts`",
|
||||
"`resource`.`memo_id` AS `memo_id`",
|
||||
"`resource`.`storage_type` AS `storage_type`",
|
||||
"`resource`.`reference` AS `reference`",
|
||||
"`resource`.`payload` AS `payload`",
|
||||
"`attachment`.`id` AS `id`",
|
||||
"`attachment`.`uid` AS `uid`",
|
||||
"`attachment`.`filename` AS `filename`",
|
||||
"`attachment`.`type` AS `type`",
|
||||
"`attachment`.`size` AS `size`",
|
||||
"`attachment`.`creator_id` AS `creator_id`",
|
||||
"`attachment`.`created_ts` AS `created_ts`",
|
||||
"`attachment`.`updated_ts` AS `updated_ts`",
|
||||
"`attachment`.`memo_id` AS `memo_id`",
|
||||
"`attachment`.`storage_type` AS `storage_type`",
|
||||
"`attachment`.`reference` AS `reference`",
|
||||
"`attachment`.`payload` AS `payload`",
|
||||
"CASE WHEN `memo`.`uid` IS NOT NULL THEN `memo`.`uid` ELSE NULL END AS `memo_uid`",
|
||||
}
|
||||
if find.GetBlob {
|
||||
fields = append(fields, "`resource`.`blob` AS `blob`")
|
||||
fields = append(fields, "`attachment`.`blob` AS `blob`")
|
||||
}
|
||||
|
||||
query := "SELECT " + strings.Join(fields, ", ") + " FROM `resource`" + " " +
|
||||
"LEFT JOIN `memo` ON `resource`.`memo_id` = `memo`.`id`" + " " +
|
||||
query := "SELECT " + strings.Join(fields, ", ") + " FROM `attachment`" + " " +
|
||||
"LEFT JOIN `memo` ON `attachment`.`memo_id` = `memo`.`id`" + " " +
|
||||
"WHERE " + strings.Join(where, " AND ") + " " +
|
||||
"ORDER BY `resource`.`updated_ts` DESC"
|
||||
"ORDER BY `attachment`.`updated_ts` DESC"
|
||||
if find.Limit != nil {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
|
|
@ -197,7 +197,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen
|
|||
}
|
||||
|
||||
args = append(args, update.ID)
|
||||
stmt := "UPDATE `resource` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
stmt := "UPDATE `attachment` SET " + strings.Join(set, ", ") + " WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to update attachment")
|
||||
|
|
@ -209,7 +209,7 @@ func (d *DB) UpdateAttachment(ctx context.Context, update *store.UpdateAttachmen
|
|||
}
|
||||
|
||||
func (d *DB) DeleteAttachment(ctx context.Context, delete *store.DeleteAttachment) error {
|
||||
stmt := "DELETE FROM `resource` WHERE `id` = ?"
|
||||
stmt := "DELETE FROM `attachment` WHERE `id` = ?"
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -55,7 +55,11 @@ func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.I
|
|||
if find.MessageType != nil {
|
||||
// Filter by message type using JSON extraction
|
||||
// Note: The type field in JSON is stored as string representation of the enum name
|
||||
where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String())
|
||||
if *find.MessageType == storepb.InboxMessage_TYPE_UNSPECIFIED {
|
||||
where, args = append(where, "(JSON_EXTRACT(`message`, '$.type') IS NULL OR JSON_EXTRACT(`message`, '$.type') = ?)"), append(args, find.MessageType.String())
|
||||
} else {
|
||||
where, args = append(where, "JSON_EXTRACT(`message`, '$.type') = ?"), append(args, find.MessageType.String())
|
||||
}
|
||||
}
|
||||
|
||||
query := "SELECT `id`, `created_ts`, `sender_id`, `receiver_id`, `status`, `message` FROM `inbox` WHERE " + strings.Join(where, " AND ") + " ORDER BY `created_ts` DESC"
|
||||
|
|
|
|||
|
|
@ -26,6 +26,18 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e
|
|||
}
|
||||
args := []any{create.UID, create.CreatorID, create.Content, create.Visibility, payload}
|
||||
|
||||
// Add custom timestamps if provided
|
||||
if create.CreatedTs != 0 {
|
||||
fields = append(fields, "`created_ts`")
|
||||
placeholder = append(placeholder, "?")
|
||||
args = append(args, create.CreatedTs)
|
||||
}
|
||||
if create.UpdatedTs != 0 {
|
||||
fields = append(fields, "`updated_ts`")
|
||||
placeholder = append(placeholder, "?")
|
||||
args = append(args, create.UpdatedTs)
|
||||
}
|
||||
|
||||
stmt := "INSERT INTO `memo` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholder, ", ") + ") RETURNING `id`, `created_ts`, `updated_ts`, `row_status`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
|
|
|
|||
|
|
@ -1,165 +0,0 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `tag in ["tag1", "tag2"]`,
|
||||
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))",
|
||||
args: []any{`%"tag1"%`, `%"tag1/%`, `%"tag2"%`, `%"tag2/%`},
|
||||
},
|
||||
{
|
||||
filter: `!(tag in ["tag1", "tag2"])`,
|
||||
want: "NOT (((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?)))",
|
||||
args: []any{`%"tag1"%`, `%"tag1/%`, `%"tag2"%`, `%"tag2/%`},
|
||||
},
|
||||
{
|
||||
filter: `content.contains("memos")`,
|
||||
want: "`memo`.`content` LIKE ?",
|
||||
args: []any{"%memos%"},
|
||||
},
|
||||
{
|
||||
filter: `visibility in ["PUBLIC"]`,
|
||||
want: "`memo`.`visibility` IN (?)",
|
||||
args: []any{"PUBLIC"},
|
||||
},
|
||||
{
|
||||
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
|
||||
want: "`memo`.`visibility` IN (?,?)",
|
||||
args: []any{"PUBLIC", "PRIVATE"},
|
||||
},
|
||||
{
|
||||
filter: `tag in ['tag1'] || content.contains('hello')`,
|
||||
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR `memo`.`content` LIKE ?)",
|
||||
args: []any{`%"tag1"%`, `%"tag1/%`, "%hello%"},
|
||||
},
|
||||
{
|
||||
filter: `1`,
|
||||
want: "",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `pinned`,
|
||||
want: "`memo`.`pinned` IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_code`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list == true`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 1",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list != false`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 0",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list == false`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 0",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `!has_task_list`,
|
||||
want: "NOT (JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list && pinned`,
|
||||
want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE AND `memo`.`pinned` IS TRUE)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_task_list && content.contains("todo")`,
|
||||
want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE AND `memo`.`content` LIKE ?)",
|
||||
args: []any{"%todo%"},
|
||||
},
|
||||
{
|
||||
filter: `created_ts > now() - 60 * 60 * 24`,
|
||||
want: "`memo`.`created_ts` > ?",
|
||||
args: []any{time.Now().Unix() - 60*60*24},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) == 0`,
|
||||
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
|
||||
args: []any{int64(0)},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) > 0`,
|
||||
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?",
|
||||
args: []any{int64(0)},
|
||||
},
|
||||
{
|
||||
filter: `"work" in tags`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
|
||||
args: []any{`%"work"%`},
|
||||
},
|
||||
{
|
||||
filter: `size(tags) == 2`,
|
||||
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
|
||||
args: []any{int64(2)},
|
||||
},
|
||||
{
|
||||
filter: `has_link == true`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_code == false`,
|
||||
want: "NOT(JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE)",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_incomplete_tasks != false`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_link`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_code`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
{
|
||||
filter: `has_incomplete_tasks`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE",
|
||||
args: []any{},
|
||||
},
|
||||
}
|
||||
|
||||
engine, err := filter.DefaultEngine()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tt := range tests {
|
||||
stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{Dialect: filter.DialectSQLite})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, stmt.SQL)
|
||||
require.Equal(t, tt.args, stmt.Args)
|
||||
}
|
||||
}
|
||||
|
|
@ -17,6 +17,7 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation)
|
|||
type
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(memo_id, related_memo_id, type) DO UPDATE SET type = excluded.type
|
||||
RETURNING memo_id, related_memo_id, type
|
||||
`
|
||||
memoRelation := &store.MemoRelation{}
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ func NewDB(profile *profile.Profile) (store.Driver, error) {
|
|||
// good practice to be explicit and prevent future surprises on SQLite upgrades.
|
||||
// - Journal mode set to WAL: it's the recommended journal mode for most applications
|
||||
// as it prevents locking issues.
|
||||
// - mmap size set to 0: it disables memory mapping, which can cause OOM errors on some systems.
|
||||
//
|
||||
// Notes:
|
||||
// - When using the `modernc.org/sqlite` driver, each pragma must be prefixed with `_pragma=`.
|
||||
|
|
@ -41,7 +42,7 @@ func NewDB(profile *profile.Profile) (store.Driver, error) {
|
|||
// - https://pkg.go.dev/modernc.org/sqlite#Driver.Open
|
||||
// - https://www.sqlite.org/sharedcache.html
|
||||
// - https://www.sqlite.org/pragma.html
|
||||
sqliteDB, err := sql.Open("sqlite", profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)")
|
||||
sqliteDB, err := sql.Open("sqlite", profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)&_pragma=mmap_size(0)")
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to open db with dsn: %s", profile.DSN)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -138,5 +138,22 @@ func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) error {
|
|||
}
|
||||
|
||||
func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error {
|
||||
// Clean up memo_relation records where this memo is either the source or target.
|
||||
if err := s.driver.DeleteMemoRelation(ctx, &DeleteMemoRelation{MemoID: &delete.ID}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.driver.DeleteMemoRelation(ctx, &DeleteMemoRelation{RelatedMemoID: &delete.ID}); err != nil {
|
||||
return err
|
||||
}
|
||||
// Clean up attachments linked to this memo.
|
||||
attachments, err := s.ListAttachments(ctx, &FindAttachment{MemoID: &delete.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, attachment := range attachments {
|
||||
if err := s.DeleteAttachment(ctx, &DeleteAttachment{ID: attachment.ID}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return s.driver.DeleteMemo(ctx, delete)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
RENAME TABLE resource TO attachment;
|
||||
|
|
@ -0,0 +1 @@
|
|||
DROP TABLE IF EXISTS memo_organizer;
|
||||
|
|
@ -0,0 +1 @@
|
|||
UPDATE `user` SET `role` = 'ADMIN' WHERE `role` = 'HOST';
|
||||
|
|
@ -42,14 +42,6 @@ CREATE TABLE `memo` (
|
|||
`payload` JSON NOT NULL
|
||||
);
|
||||
|
||||
-- memo_organizer
|
||||
CREATE TABLE `memo_organizer` (
|
||||
`memo_id` INT NOT NULL,
|
||||
`user_id` INT NOT NULL,
|
||||
`pinned` INT NOT NULL DEFAULT '0',
|
||||
UNIQUE(`memo_id`,`user_id`)
|
||||
);
|
||||
|
||||
-- memo_relation
|
||||
CREATE TABLE `memo_relation` (
|
||||
`memo_id` INT NOT NULL,
|
||||
|
|
@ -58,8 +50,8 @@ CREATE TABLE `memo_relation` (
|
|||
UNIQUE(`memo_id`,`related_memo_id`,`type`)
|
||||
);
|
||||
|
||||
-- resource
|
||||
CREATE TABLE `resource` (
|
||||
-- attachment
|
||||
CREATE TABLE `attachment` (
|
||||
`id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
`uid` VARCHAR(256) NOT NULL UNIQUE,
|
||||
`creator_id` INT NOT NULL,
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
ALTER TABLE resource RENAME TO attachment;
|
||||
|
|
@ -0,0 +1 @@
|
|||
DROP TABLE IF EXISTS memo_organizer;
|
||||
|
|
@ -0,0 +1 @@
|
|||
UPDATE "user" SET role = 'ADMIN' WHERE role = 'HOST';
|
||||
|
|
@ -42,14 +42,6 @@ CREATE TABLE memo (
|
|||
payload JSONB NOT NULL DEFAULT '{}'
|
||||
);
|
||||
|
||||
-- memo_organizer
|
||||
CREATE TABLE memo_organizer (
|
||||
memo_id INTEGER NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
pinned INTEGER NOT NULL DEFAULT 0,
|
||||
UNIQUE(memo_id, user_id)
|
||||
);
|
||||
|
||||
-- memo_relation
|
||||
CREATE TABLE memo_relation (
|
||||
memo_id INTEGER NOT NULL,
|
||||
|
|
@ -58,8 +50,8 @@ CREATE TABLE memo_relation (
|
|||
UNIQUE(memo_id, related_memo_id, type)
|
||||
);
|
||||
|
||||
-- resource
|
||||
CREATE TABLE resource (
|
||||
-- attachment
|
||||
CREATE TABLE attachment (
|
||||
id SERIAL PRIMARY KEY,
|
||||
uid TEXT NOT NULL UNIQUE,
|
||||
creator_id INTEGER NOT NULL,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
ALTER TABLE `resource` RENAME TO `attachment`;
|
||||
DROP INDEX IF EXISTS `idx_resource_creator_id`;
|
||||
CREATE INDEX `idx_attachment_creator_id` ON `attachment` (`creator_id`);
|
||||
DROP INDEX IF EXISTS `idx_resource_memo_id`;
|
||||
CREATE INDEX `idx_attachment_memo_id` ON `attachment` (`memo_id`);
|
||||
|
|
@ -0,0 +1 @@
|
|||
DROP TABLE IF EXISTS memo_organizer;
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
DROP INDEX IF EXISTS idx_user_username;
|
||||
DROP INDEX IF EXISTS idx_memo_creator_id;
|
||||
DROP INDEX IF EXISTS idx_attachment_creator_id;
|
||||
DROP INDEX IF EXISTS idx_attachment_memo_id;
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
ALTER TABLE user RENAME TO user_old;
|
||||
|
||||
CREATE TABLE user (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
|
||||
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
|
||||
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
role TEXT NOT NULL DEFAULT 'USER',
|
||||
email TEXT NOT NULL DEFAULT '',
|
||||
nickname TEXT NOT NULL DEFAULT '',
|
||||
password_hash TEXT NOT NULL,
|
||||
avatar_url TEXT NOT NULL DEFAULT '',
|
||||
description TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
INSERT INTO user (
|
||||
id, created_ts, updated_ts, row_status, username, role, email, nickname, password_hash, avatar_url, description
|
||||
)
|
||||
SELECT
|
||||
id, created_ts, updated_ts, row_status, username, role, email, nickname, password_hash, avatar_url, description
|
||||
FROM user_old;
|
||||
|
||||
DROP TABLE user_old;
|
||||
|
|
@ -0,0 +1 @@
|
|||
UPDATE user SET role = 'ADMIN' WHERE role = 'HOST';
|
||||
|
|
@ -13,7 +13,7 @@ CREATE TABLE user (
|
|||
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
|
||||
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
role TEXT NOT NULL CHECK (role IN ('HOST', 'ADMIN', 'USER')) DEFAULT 'USER',
|
||||
role TEXT NOT NULL DEFAULT 'USER',
|
||||
email TEXT NOT NULL DEFAULT '',
|
||||
nickname TEXT NOT NULL DEFAULT '',
|
||||
password_hash TEXT NOT NULL,
|
||||
|
|
@ -21,8 +21,6 @@ CREATE TABLE user (
|
|||
description TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
CREATE INDEX idx_user_username ON user (username);
|
||||
|
||||
-- user_setting
|
||||
CREATE TABLE user_setting (
|
||||
user_id INTEGER NOT NULL,
|
||||
|
|
@ -45,16 +43,6 @@ CREATE TABLE memo (
|
|||
payload TEXT NOT NULL DEFAULT '{}'
|
||||
);
|
||||
|
||||
CREATE INDEX idx_memo_creator_id ON memo (creator_id);
|
||||
|
||||
-- memo_organizer
|
||||
CREATE TABLE memo_organizer (
|
||||
memo_id INTEGER NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
pinned INTEGER NOT NULL CHECK (pinned IN (0, 1)) DEFAULT 0,
|
||||
UNIQUE(memo_id, user_id)
|
||||
);
|
||||
|
||||
-- memo_relation
|
||||
CREATE TABLE memo_relation (
|
||||
memo_id INTEGER NOT NULL,
|
||||
|
|
@ -63,8 +51,8 @@ CREATE TABLE memo_relation (
|
|||
UNIQUE(memo_id, related_memo_id, type)
|
||||
);
|
||||
|
||||
-- resource
|
||||
CREATE TABLE resource (
|
||||
-- attachment
|
||||
CREATE TABLE attachment (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
uid TEXT NOT NULL UNIQUE,
|
||||
creator_id INTEGER NOT NULL,
|
||||
|
|
@ -80,10 +68,6 @@ CREATE TABLE resource (
|
|||
payload TEXT NOT NULL DEFAULT '{}'
|
||||
);
|
||||
|
||||
CREATE INDEX idx_resource_creator_id ON resource (creator_id);
|
||||
|
||||
CREATE INDEX idx_resource_memo_id ON resource (memo_id);
|
||||
|
||||
-- activity
|
||||
CREATE TABLE activity (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ import (
|
|||
// Migration System Overview:
|
||||
//
|
||||
// The migration system handles database schema versioning and upgrades.
|
||||
// Schema version is stored in instance_setting (formerly system_setting).
|
||||
// Schema version is stored in system_setting.
|
||||
//
|
||||
// Migration Flow:
|
||||
// 1. preMigrate: Check if DB is initialized. If not, apply LATEST.sql
|
||||
|
|
@ -30,9 +30,9 @@ import (
|
|||
// 4. Migrate (demo mode): Seed database with demo data
|
||||
//
|
||||
// Version Tracking:
|
||||
// - New installations: Schema version set in instance_setting immediately
|
||||
// - Existing v0.22+ installations: Schema version tracked in instance_setting
|
||||
// - Pre-v0.22 installations: Must upgrade to v0.25.x first (migration_history → instance_setting migration)
|
||||
// - New installations: Schema version set in system_setting immediately
|
||||
// - Existing v0.22+ installations: Schema version tracked in system_setting
|
||||
// - Pre-v0.22 installations: Must upgrade to v0.25.x first (migration_history → system_setting migration)
|
||||
//
|
||||
// Migration Files:
|
||||
// - Location: store/migration/{driver}/{version}/NN__description.sql
|
||||
|
|
@ -57,10 +57,6 @@ const (
|
|||
// defaultSchemaVersion is used when schema version is empty or not set.
|
||||
// This handles edge cases for old installations without version tracking.
|
||||
defaultSchemaVersion = "0.0.0"
|
||||
|
||||
// Mode constants for profile mode.
|
||||
modeProd = "prod"
|
||||
modeDemo = "demo"
|
||||
)
|
||||
|
||||
// getSchemaVersionOrDefault returns the schema version or default if empty.
|
||||
|
|
@ -110,38 +106,36 @@ func (s *Store) Migrate(ctx context.Context) error {
|
|||
return errors.Wrap(err, "failed to pre-migrate")
|
||||
}
|
||||
|
||||
switch s.profile.Mode {
|
||||
case modeProd:
|
||||
instanceBasicSetting, err := s.GetInstanceBasicSetting(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get instance basic setting")
|
||||
instanceBasicSetting, err := s.GetInstanceBasicSetting(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get instance basic setting")
|
||||
}
|
||||
currentSchemaVersion, err := s.GetCurrentSchemaVersion()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get current schema version")
|
||||
}
|
||||
// Check for downgrade (but skip if schema version is empty - that means fresh/old installation)
|
||||
if !isVersionEmpty(instanceBasicSetting.SchemaVersion) && version.IsVersionGreaterThan(instanceBasicSetting.SchemaVersion, currentSchemaVersion) {
|
||||
slog.Error("cannot downgrade schema version",
|
||||
slog.String("databaseVersion", instanceBasicSetting.SchemaVersion),
|
||||
slog.String("currentVersion", currentSchemaVersion),
|
||||
)
|
||||
return errors.Errorf("cannot downgrade schema version from %s to %s", instanceBasicSetting.SchemaVersion, currentSchemaVersion)
|
||||
}
|
||||
// Apply migrations if needed (including when schema version is empty)
|
||||
if isVersionEmpty(instanceBasicSetting.SchemaVersion) || version.IsVersionGreaterThan(currentSchemaVersion, instanceBasicSetting.SchemaVersion) {
|
||||
if err := s.applyMigrations(ctx, instanceBasicSetting.SchemaVersion, currentSchemaVersion); err != nil {
|
||||
return errors.Wrap(err, "failed to apply migrations")
|
||||
}
|
||||
currentSchemaVersion, err := s.GetCurrentSchemaVersion()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get current schema version")
|
||||
}
|
||||
// Check for downgrade (but skip if schema version is empty - that means fresh/old installation)
|
||||
if !isVersionEmpty(instanceBasicSetting.SchemaVersion) && version.IsVersionGreaterThan(instanceBasicSetting.SchemaVersion, currentSchemaVersion) {
|
||||
slog.Error("cannot downgrade schema version",
|
||||
slog.String("databaseVersion", instanceBasicSetting.SchemaVersion),
|
||||
slog.String("currentVersion", currentSchemaVersion),
|
||||
)
|
||||
return errors.Errorf("cannot downgrade schema version from %s to %s", instanceBasicSetting.SchemaVersion, currentSchemaVersion)
|
||||
}
|
||||
// Apply migrations if needed (including when schema version is empty)
|
||||
if isVersionEmpty(instanceBasicSetting.SchemaVersion) || version.IsVersionGreaterThan(currentSchemaVersion, instanceBasicSetting.SchemaVersion) {
|
||||
if err := s.applyMigrations(ctx, instanceBasicSetting.SchemaVersion, currentSchemaVersion); err != nil {
|
||||
return errors.Wrap(err, "failed to apply migrations")
|
||||
}
|
||||
}
|
||||
case modeDemo:
|
||||
}
|
||||
|
||||
if s.profile.Demo {
|
||||
// In demo mode, we should seed the database.
|
||||
if err := s.seed(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to seed")
|
||||
}
|
||||
default:
|
||||
// For other modes (like dev), no special migration handling needed
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -255,14 +249,11 @@ func (s *Store) preMigrate(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
if s.profile.Mode == modeProd {
|
||||
if err := s.checkMinimumUpgradeVersion(ctx); err != nil {
|
||||
return err // Error message is already descriptive, don't wrap it
|
||||
}
|
||||
if err := s.checkMinimumUpgradeVersion(ctx); err != nil {
|
||||
return err // Error message is already descriptive, don't wrap it
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) getMigrationBasePath() string {
|
||||
return fmt.Sprintf("migration/%s/", s.profile.Driver)
|
||||
}
|
||||
|
|
@ -308,7 +299,7 @@ func (s *Store) seed(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (s *Store) GetCurrentSchemaVersion() (string, error) {
|
||||
currentVersion := version.GetCurrentVersion(s.profile.Mode)
|
||||
currentVersion := version.GetCurrentVersion()
|
||||
minorVersion := version.GetMinorVersion(currentVersion)
|
||||
filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s%s/*.sql", s.getMigrationBasePath(), minorVersion))
|
||||
if err != nil {
|
||||
|
|
@ -373,7 +364,7 @@ func (s *Store) updateCurrentSchemaVersion(ctx context.Context, schemaVersion st
|
|||
|
||||
// checkMinimumUpgradeVersion verifies the installation meets minimum version requirements for upgrade.
|
||||
// For very old installations (< v0.22.0), users must upgrade to v0.25.x first before upgrading to current version.
|
||||
// This is necessary because schema version tracking was moved from migration_history to instance_setting in v0.22.0.
|
||||
// This is necessary because schema version tracking was moved from migration_history to system_setting in v0.22.0.
|
||||
func (s *Store) checkMinimumUpgradeVersion(ctx context.Context) error {
|
||||
instanceBasicSetting, err := s.GetInstanceBasicSetting(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -401,7 +392,7 @@ func (s *Store) checkMinimumUpgradeVersion(ctx context.Context) error {
|
|||
"2. Start the server and verify it works\n"+
|
||||
"3. Then upgrade to the latest version\n\n"+
|
||||
"This is required because schema version tracking was moved from migration_history\n"+
|
||||
"to instance_setting in v0.22.0. The intermediate upgrade handles this migration safely.",
|
||||
"to system_setting in v0.22.0. The intermediate upgrade handles this migration safely.",
|
||||
schemaVersion,
|
||||
currentVersion,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ The demo data includes **6 carefully selected memos** that showcase the key feat
|
|||
|
||||
- **Username**: `demo`
|
||||
- **Password**: `secret` (default password)
|
||||
- **Role**: HOST
|
||||
- **Role**: ADMIN
|
||||
- **Nickname**: Demo User
|
||||
|
||||
## Demo Memos (6 total)
|
||||
|
|
@ -174,10 +174,10 @@ To run with demo data:
|
|||
|
||||
```bash
|
||||
# Start in demo mode
|
||||
go run ./cmd/memos --mode demo --port 8081
|
||||
go run ./cmd/memos --demo --port 8081
|
||||
|
||||
# Or use the binary
|
||||
./memos --mode demo
|
||||
./memos --demo
|
||||
|
||||
# Demo database location
|
||||
./build/memos_demo.db
|
||||
|
|
@ -198,7 +198,7 @@ Login with:
|
|||
|
||||
- All memos are set to PUBLIC visibility
|
||||
- **Two memos are pinned**: Welcome (#1) and Sponsor (#6)
|
||||
- User has HOST role to showcase all features
|
||||
- User has ADMIN role to showcase all features
|
||||
- Reactions are distributed across memos
|
||||
- One memo relation demonstrates linking
|
||||
- Content is optimized for the compact markdown styles
|
||||
|
|
|
|||
|
|
@ -1,11 +0,0 @@
|
|||
DELETE FROM system_setting;
|
||||
DELETE FROM user;
|
||||
DELETE FROM user_setting;
|
||||
DELETE FROM memo;
|
||||
DELETE FROM memo_organizer;
|
||||
DELETE FROM memo_relation;
|
||||
DELETE FROM resource;
|
||||
DELETE FROM activity;
|
||||
DELETE FROM idp;
|
||||
DELETE FROM inbox;
|
||||
DELETE FROM reaction;
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
-- Demo User
|
||||
INSERT INTO user (id,username,role,nickname,password_hash) VALUES(1,'demo','HOST','Demo User','$2a$10$c.slEVgf5b/3BnAWlLb/vOu7VVSOKJ4ljwMe9xzlx9IhKnvAsJYM6');
|
||||
INSERT INTO user (id,username,role,nickname,password_hash) VALUES(1,'demo','ADMIN','Demo User','$2a$10$c.slEVgf5b/3BnAWlLb/vOu7VVSOKJ4ljwMe9xzlx9IhKnvAsJYM6');
|
||||
|
||||
-- Welcome Memo (Pinned)
|
||||
INSERT INTO memo (id,uid,creator_id,content,visibility,pinned,payload) VALUES(1,'welcome2memos001',1,replace('# Welcome to Memos!\\n\\nA privacy-first, lightweight note-taking service. Easily capture and share your great thoughts.\\n\\n## Key Features\\n\\n- **Privacy First**: Your data stays with you\\n- **Markdown Support**: Full CommonMark + GFM syntax\\n- **Quick Capture**: Jot down thoughts instantly\\n- **Organize with Tags**: Use #tags to categorize\\n- **Open Source**: Free and open source software\\n\\n---\\n\\nStart exploring the demo memos below to see what you can do! #welcome #getting-started','\\n',char(10)),'PUBLIC',1,'{"tags":["welcome","getting-started"],"property":{"hasLink":false}}');
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
)
|
||||
|
||||
func TestActivityStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -34,6 +35,7 @@ func TestActivityStore(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestActivityGetByID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -63,6 +65,7 @@ func TestActivityGetByID(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestActivityListMultiple(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -99,3 +102,279 @@ func TestActivityListMultiple(t *testing.T) {
|
|||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityListByType(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create activities with MEMO_COMMENT type
|
||||
_, err = ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by type
|
||||
activityType := store.ActivityTypeMemoComment
|
||||
activities, err := ts.ListActivities(ctx, &store.FindActivity{Type: &activityType})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activities, 2)
|
||||
for _, activity := range activities {
|
||||
require.Equal(t, store.ActivityTypeMemoComment, activity.Type)
|
||||
}
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityPayloadMemoComment(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create activity with MemoComment payload
|
||||
memoID := int32(123)
|
||||
relatedMemoID := int32(456)
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{
|
||||
MemoComment: &storepb.ActivityMemoCommentPayload{
|
||||
MemoId: memoID,
|
||||
RelatedMemoId: relatedMemoID,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, activity.Payload)
|
||||
require.NotNil(t, activity.Payload.MemoComment)
|
||||
require.Equal(t, memoID, activity.Payload.MemoComment.MemoId)
|
||||
require.Equal(t, relatedMemoID, activity.Payload.MemoComment.RelatedMemoId)
|
||||
|
||||
// Verify payload is preserved when listing
|
||||
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found.Payload.MemoComment)
|
||||
require.Equal(t, memoID, found.Payload.MemoComment.MemoId)
|
||||
require.Equal(t, relatedMemoID, found.Payload.MemoComment.RelatedMemoId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityEmptyPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create activity with empty payload
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, activity.Payload)
|
||||
|
||||
// Verify empty payload is handled correctly
|
||||
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found.Payload)
|
||||
require.Nil(t, found.Payload.MemoComment)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityLevel(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create activity with INFO level
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.ActivityLevelInfo, activity.Level)
|
||||
|
||||
// Verify level is preserved when listing
|
||||
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.ActivityLevelInfo, found.Level)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityCreatorID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user1, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create activity for user1
|
||||
activity1, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user1.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user1.ID, activity1.CreatorID)
|
||||
|
||||
// Create activity for user2
|
||||
activity2, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user2.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user2.ID, activity2.CreatorID)
|
||||
|
||||
// List all and verify creator IDs
|
||||
activities, err := ts.ListActivities(ctx, &store.FindActivity{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activities, 2)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityCreatedTs(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, activity.CreatedTs)
|
||||
|
||||
// Verify timestamp is preserved when listing
|
||||
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, activity.CreatedTs, found.CreatedTs)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityListEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// List activities when none exist
|
||||
activities, err := ts.ListActivities(ctx, &store.FindActivity{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activities, 0)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityListWithIDAndType(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List with both ID and Type filters
|
||||
activityType := store.ActivityTypeMemoComment
|
||||
activities, err := ts.ListActivities(ctx, &store.FindActivity{
|
||||
ID: &activity.ID,
|
||||
Type: &activityType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activities, 1)
|
||||
require.Equal(t, activity.ID, activities[0].ID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestActivityPayloadComplexMemoComment(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a memo first to use its ID
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "test-memo-for-activity",
|
||||
CreatorID: user.ID,
|
||||
Content: "Test memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create comment memo
|
||||
commentMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "comment-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "This is a comment",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create activity with real memo IDs
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{
|
||||
MemoComment: &storepb.ActivityMemoCommentPayload{
|
||||
MemoId: memo.ID,
|
||||
RelatedMemoId: commentMemo.ID,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, memo.ID, activity.Payload.MemoComment.MemoId)
|
||||
require.Equal(t, commentMemo.ID, activity.Payload.MemoComment.RelatedMemoId)
|
||||
|
||||
// Verify payload is preserved
|
||||
found, err := ts.GetActivity(ctx, &store.FindActivity{ID: &activity.ID})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, memo.ID, found.Payload.MemoComment.MemoId)
|
||||
require.Equal(t, commentMemo.ID, found.Payload.MemoComment.RelatedMemoId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import (
|
|||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterFilenameContains(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -35,6 +36,7 @@ func TestAttachmentFilterFilenameContains(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentFilterFilenameSpecialCharacters(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -51,6 +53,7 @@ func TestAttachmentFilterFilenameSpecialCharacters(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentFilterFilenameUnicode(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -67,6 +70,7 @@ func TestAttachmentFilterFilenameUnicode(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterMimeTypeEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -86,6 +90,7 @@ func TestAttachmentFilterMimeTypeEquals(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentFilterMimeTypeNotEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -98,6 +103,7 @@ func TestAttachmentFilterMimeTypeNotEquals(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentFilterMimeTypeInList(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -121,6 +127,7 @@ func TestAttachmentFilterMimeTypeInList(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterCreateTimeComparison(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -141,6 +148,7 @@ func TestAttachmentFilterCreateTimeComparison(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentFilterCreateTimeWithNow(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -156,6 +164,7 @@ func TestAttachmentFilterCreateTimeWithNow(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentFilterCreateTimeArithmetic(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -175,6 +184,7 @@ func TestAttachmentFilterCreateTimeArithmetic(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentFilterAllComparisonOperators(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -203,6 +213,7 @@ func TestAttachmentFilterAllComparisonOperators(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterMemoIdEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContextWithUser(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -218,6 +229,7 @@ func TestAttachmentFilterMemoIdEquals(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentFilterMemoIdNotEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContextWithUser(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -238,6 +250,7 @@ func TestAttachmentFilterMemoIdNotEquals(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterLogicalAnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -251,6 +264,7 @@ func TestAttachmentFilterLogicalAnd(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentFilterLogicalOr(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -263,6 +277,7 @@ func TestAttachmentFilterLogicalOr(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentFilterLogicalNot(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -275,6 +290,7 @@ func TestAttachmentFilterLogicalNot(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentFilterComplexLogical(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -291,6 +307,7 @@ func TestAttachmentFilterComplexLogical(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterMultipleFilters(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -310,6 +327,7 @@ func TestAttachmentFilterMultipleFilters(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestAttachmentFilterNoMatches(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -320,6 +338,7 @@ func TestAttachmentFilterNoMatches(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentFilterNullMemoId(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContextWithUser(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -343,6 +362,7 @@ func TestAttachmentFilterNullMemoId(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentFilterEmptyFilename(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewAttachmentFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import (
|
|||
)
|
||||
|
||||
func TestAttachmentStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
_, err := ts.CreateAttachment(ctx, &store.Attachment{
|
||||
|
|
@ -64,6 +65,7 @@ func TestAttachmentStore(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentStoreWithFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -123,6 +125,7 @@ func TestAttachmentStoreWithFilter(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -153,6 +156,7 @@ func TestAttachmentUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentGetByUID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -183,6 +187,7 @@ func TestAttachmentGetByUID(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentListWithPagination(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -222,6 +227,7 @@ func TestAttachmentListWithPagination(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAttachmentInvalidUID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,16 +4,19 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mysql"
|
||||
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/network"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
|
||||
// Database drivers for connection verification.
|
||||
|
|
@ -24,23 +27,51 @@ import (
|
|||
const (
|
||||
testUser = "root"
|
||||
testPassword = "test"
|
||||
|
||||
// Memos container settings for migration testing.
|
||||
MemosDockerImage = "neosmemo/memos"
|
||||
StableMemosVersion = "stable" // Always points to the latest stable release
|
||||
)
|
||||
|
||||
var (
|
||||
mysqlContainer *mysql.MySQLContainer
|
||||
postgresContainer *postgres.PostgresContainer
|
||||
mysqlContainer atomic.Pointer[mysql.MySQLContainer]
|
||||
postgresContainer atomic.Pointer[postgres.PostgresContainer]
|
||||
mysqlOnce sync.Once
|
||||
postgresOnce sync.Once
|
||||
mysqlBaseDSN string
|
||||
postgresBaseDSN string
|
||||
mysqlBaseDSN atomic.Value // stores string
|
||||
postgresBaseDSN atomic.Value // stores string
|
||||
dbCounter atomic.Int64
|
||||
dbCreationMutex sync.Mutex // Protects database creation operations
|
||||
|
||||
// Network for container communication.
|
||||
testDockerNetwork atomic.Pointer[testcontainers.DockerNetwork]
|
||||
testNetworkOnce sync.Once
|
||||
)
|
||||
|
||||
// getTestNetwork creates or returns the shared Docker network for container communication.
|
||||
func getTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error) {
|
||||
var networkErr error
|
||||
testNetworkOnce.Do(func() {
|
||||
nw, err := network.New(ctx, network.WithDriver("bridge"))
|
||||
if err != nil {
|
||||
networkErr = err
|
||||
return
|
||||
}
|
||||
testDockerNetwork.Store(nw)
|
||||
})
|
||||
return testDockerNetwork.Load(), networkErr
|
||||
}
|
||||
|
||||
// GetMySQLDSN starts a MySQL container (if not already running) and creates a fresh database for this test.
|
||||
func GetMySQLDSN(t *testing.T) string {
|
||||
ctx := context.Background()
|
||||
|
||||
mysqlOnce.Do(func() {
|
||||
nw, err := getTestNetwork(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test network: %v", err)
|
||||
}
|
||||
|
||||
container, err := mysql.Run(ctx,
|
||||
"mysql:8",
|
||||
mysql.WithDatabase("init_db"),
|
||||
|
|
@ -55,11 +86,12 @@ func GetMySQLDSN(t *testing.T) string {
|
|||
wait.ForListeningPort("3306/tcp"),
|
||||
).WithDeadline(120*time.Second),
|
||||
),
|
||||
network.WithNetwork(nil, nw),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start MySQL container: %v", err)
|
||||
}
|
||||
mysqlContainer = container
|
||||
mysqlContainer.Store(container)
|
||||
|
||||
dsn, err := container.ConnectionString(ctx, "multiStatements=true")
|
||||
if err != nil {
|
||||
|
|
@ -70,16 +102,21 @@ func GetMySQLDSN(t *testing.T) string {
|
|||
t.Fatalf("MySQL not ready for connections: %v", err)
|
||||
}
|
||||
|
||||
mysqlBaseDSN = dsn
|
||||
mysqlBaseDSN.Store(dsn)
|
||||
})
|
||||
|
||||
if mysqlBaseDSN == "" {
|
||||
dsn, ok := mysqlBaseDSN.Load().(string)
|
||||
if !ok || dsn == "" {
|
||||
t.Fatal("MySQL container failed to start in a previous test")
|
||||
}
|
||||
|
||||
// Serialize database creation to avoid "table already exists" race conditions
|
||||
dbCreationMutex.Lock()
|
||||
defer dbCreationMutex.Unlock()
|
||||
|
||||
// Create a fresh database for this test
|
||||
dbName := fmt.Sprintf("memos_test_%d", dbCounter.Add(1))
|
||||
db, err := sql.Open("mysql", mysqlBaseDSN)
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to MySQL: %v", err)
|
||||
}
|
||||
|
|
@ -90,7 +127,7 @@ func GetMySQLDSN(t *testing.T) string {
|
|||
}
|
||||
|
||||
// Return DSN pointing to the new database
|
||||
return strings.Replace(mysqlBaseDSN, "/init_db?", "/"+dbName+"?", 1)
|
||||
return strings.Replace(dsn, "/init_db?", "/"+dbName+"?", 1)
|
||||
}
|
||||
|
||||
// waitForDB polls the database until it's ready or timeout is reached.
|
||||
|
|
@ -130,6 +167,11 @@ func GetPostgresDSN(t *testing.T) string {
|
|||
ctx := context.Background()
|
||||
|
||||
postgresOnce.Do(func() {
|
||||
nw, err := getTestNetwork(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test network: %v", err)
|
||||
}
|
||||
|
||||
container, err := postgres.Run(ctx,
|
||||
"postgres:18",
|
||||
postgres.WithDatabase("init_db"),
|
||||
|
|
@ -141,11 +183,12 @@ func GetPostgresDSN(t *testing.T) string {
|
|||
wait.ForListeningPort("5432/tcp"),
|
||||
).WithDeadline(120*time.Second),
|
||||
),
|
||||
network.WithNetwork(nil, nw),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start PostgreSQL container: %v", err)
|
||||
}
|
||||
postgresContainer = container
|
||||
postgresContainer.Store(container)
|
||||
|
||||
dsn, err := container.ConnectionString(ctx, "sslmode=disable")
|
||||
if err != nil {
|
||||
|
|
@ -156,16 +199,21 @@ func GetPostgresDSN(t *testing.T) string {
|
|||
t.Fatalf("PostgreSQL not ready for connections: %v", err)
|
||||
}
|
||||
|
||||
postgresBaseDSN = dsn
|
||||
postgresBaseDSN.Store(dsn)
|
||||
})
|
||||
|
||||
if postgresBaseDSN == "" {
|
||||
dsn, ok := postgresBaseDSN.Load().(string)
|
||||
if !ok || dsn == "" {
|
||||
t.Fatal("PostgreSQL container failed to start in a previous test")
|
||||
}
|
||||
|
||||
// Serialize database creation to avoid "table already exists" race conditions
|
||||
dbCreationMutex.Lock()
|
||||
defer dbCreationMutex.Unlock()
|
||||
|
||||
// Create a fresh database for this test
|
||||
dbName := fmt.Sprintf("memos_test_%d", dbCounter.Add(1))
|
||||
db, err := sql.Open("postgres", postgresBaseDSN)
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to PostgreSQL: %v", err)
|
||||
}
|
||||
|
|
@ -176,17 +224,94 @@ func GetPostgresDSN(t *testing.T) string {
|
|||
}
|
||||
|
||||
// Return DSN pointing to the new database
|
||||
return strings.Replace(postgresBaseDSN, "/init_db?", "/"+dbName+"?", 1)
|
||||
return strings.Replace(dsn, "/init_db?", "/"+dbName+"?", 1)
|
||||
}
|
||||
|
||||
// TerminateContainers cleans up all running containers.
|
||||
// TerminateContainers cleans up all running containers and network.
|
||||
// This is typically called from TestMain.
|
||||
func TerminateContainers() {
|
||||
ctx := context.Background()
|
||||
if mysqlContainer != nil {
|
||||
_ = mysqlContainer.Terminate(ctx)
|
||||
if container := mysqlContainer.Load(); container != nil {
|
||||
_ = container.Terminate(ctx)
|
||||
}
|
||||
if postgresContainer != nil {
|
||||
_ = postgresContainer.Terminate(ctx)
|
||||
if container := postgresContainer.Load(); container != nil {
|
||||
_ = container.Terminate(ctx)
|
||||
}
|
||||
if network := testDockerNetwork.Load(); network != nil {
|
||||
_ = network.Remove(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// MemosContainerConfig holds configuration for starting a Memos container.
|
||||
type MemosContainerConfig struct {
|
||||
Version string // Memos version tag (e.g., "0.24.0")
|
||||
Driver string // Database driver: sqlite, mysql, postgres
|
||||
DSN string // Database DSN (for mysql/postgres)
|
||||
DataDir string // Host directory to mount for SQLite data
|
||||
}
|
||||
|
||||
// MemosStartupWaitStrategy defines the wait strategy for Memos container startup.
|
||||
// Uses regex to match various log message formats across versions.
|
||||
var MemosStartupWaitStrategy = wait.ForAll(
|
||||
wait.ForLog("(started successfully|has been started on port)").AsRegexp(),
|
||||
wait.ForListeningPort("5230/tcp"),
|
||||
).WithDeadline(180 * time.Second)
|
||||
|
||||
// StartMemosContainer starts a Memos container for migration testing.
|
||||
// For SQLite, it mounts the dataDir to /var/opt/memos.
|
||||
func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcontainers.Container, error) {
|
||||
env := map[string]string{
|
||||
"MEMOS_MODE": "prod",
|
||||
}
|
||||
|
||||
var opts []testcontainers.ContainerCustomizer
|
||||
|
||||
switch cfg.Driver {
|
||||
case "sqlite":
|
||||
env["MEMOS_DRIVER"] = "sqlite"
|
||||
opts = append(opts, testcontainers.WithHostConfigModifier(func(hc *container.HostConfig) {
|
||||
hc.Binds = append(hc.Binds, fmt.Sprintf("%s:%s", cfg.DataDir, "/var/opt/memos"))
|
||||
}))
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported driver for migration testing: %s", cfg.Driver)
|
||||
}
|
||||
|
||||
req := testcontainers.ContainerRequest{
|
||||
Image: fmt.Sprintf("%s:%s", MemosDockerImage, cfg.Version),
|
||||
Env: env,
|
||||
ExposedPorts: []string{"5230/tcp"},
|
||||
WaitingFor: MemosStartupWaitStrategy,
|
||||
User: fmt.Sprintf("%d:%d", os.Getuid(), os.Getgid()),
|
||||
}
|
||||
|
||||
// Use local image if specified
|
||||
if cfg.Version == "local" {
|
||||
if os.Getenv("MEMOS_TEST_IMAGE_BUILT") == "1" {
|
||||
req.Image = "memos-test:local"
|
||||
} else {
|
||||
req.FromDockerfile = testcontainers.FromDockerfile{
|
||||
Context: "../../",
|
||||
Dockerfile: "Dockerfile",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
genericReq := testcontainers.GenericContainerRequest{
|
||||
ContainerRequest: req,
|
||||
Started: true,
|
||||
}
|
||||
|
||||
// Apply options
|
||||
for _, opt := range opts {
|
||||
if err := opt.Customize(&genericReq); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to apply container option")
|
||||
}
|
||||
}
|
||||
|
||||
ctr, err := testcontainers.GenericContainer(ctx, genericReq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to start memos container")
|
||||
}
|
||||
|
||||
return ctr, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
)
|
||||
|
||||
func TestIdentityProviderStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
createdIDP, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
|
||||
|
|
@ -58,3 +59,396 @@ func TestIdentityProviderStore(t *testing.T) {
|
|||
require.Equal(t, 0, len(idpList))
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderGetByID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create IDP
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get by ID
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
require.Equal(t, idp.Id, found.Id)
|
||||
require.Equal(t, idp.Name, found.Name)
|
||||
|
||||
// Get by non-existent ID
|
||||
nonExistentID := int32(99999)
|
||||
notFound, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &nonExistentID})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notFound)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderListMultiple(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create multiple IDPs
|
||||
_, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth"))
|
||||
require.NoError(t, err)
|
||||
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth"))
|
||||
require.NoError(t, err)
|
||||
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitLab OAuth"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// List all
|
||||
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, idpList, 3)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderListByID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create multiple IDPs
|
||||
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth"))
|
||||
require.NoError(t, err)
|
||||
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by specific ID
|
||||
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{ID: &idp1.Id})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, idpList, 1)
|
||||
require.Equal(t, "GitHub OAuth", idpList[0].Name)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderUpdateName(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original Name"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Original Name", idp.Name)
|
||||
|
||||
// Update name
|
||||
newName := "Updated Name"
|
||||
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
|
||||
ID: idp.Id,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
Name: &newName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Updated Name", updated.Name)
|
||||
|
||||
// Verify update persisted
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Updated Name", found.Name)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderUpdateIdentifierFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", idp.IdentifierFilter)
|
||||
|
||||
// Update identifier filter
|
||||
newFilter := "@example.com$"
|
||||
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
|
||||
ID: idp.Id,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
IdentifierFilter: &newFilter,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "@example.com$", updated.IdentifierFilter)
|
||||
|
||||
// Verify update persisted
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "@example.com$", found.IdentifierFilter)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderUpdateConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update config
|
||||
newConfig := &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: "new_client_id",
|
||||
ClientSecret: "new_client_secret",
|
||||
AuthUrl: "https://newprovider.com/auth",
|
||||
TokenUrl: "https://newprovider.com/token",
|
||||
UserInfoUrl: "https://newprovider.com/user",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "sub",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
|
||||
ID: idp.Id,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
Config: newConfig,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "new_client_id", updated.Config.GetOauth2Config().ClientId)
|
||||
require.Equal(t, "new_client_secret", updated.Config.GetOauth2Config().ClientSecret)
|
||||
require.Contains(t, updated.Config.GetOauth2Config().Scopes, "openid")
|
||||
|
||||
// Verify update persisted
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "new_client_id", found.Config.GetOauth2Config().ClientId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderUpdateMultipleFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update multiple fields at once
|
||||
newName := "Updated IDP"
|
||||
newFilter := "^admin@"
|
||||
updated, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderV1{
|
||||
ID: idp.Id,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
Name: &newName,
|
||||
IdentifierFilter: &newFilter,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Updated IDP", updated.Name)
|
||||
require.Equal(t, "^admin@", updated.IdentifierFilter)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete
|
||||
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp.Id})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, found)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderDeleteNotAffectOthers(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create multiple IDPs
|
||||
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 1"))
|
||||
require.NoError(t, err)
|
||||
idp2, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete first one
|
||||
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp1.Id})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify second still exists
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp2.Id})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
require.Equal(t, "IDP 2", found.Name)
|
||||
|
||||
// Verify list only contains second
|
||||
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, idpList, 1)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderOAuth2ConfigScopes(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create IDP with multiple scopes
|
||||
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
|
||||
Name: "Multi-Scope OAuth",
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
Config: &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
AuthUrl: "https://provider.com/auth",
|
||||
TokenUrl: "https://provider.com/token",
|
||||
UserInfoUrl: "https://provider.com/userinfo",
|
||||
Scopes: []string{"openid", "profile", "email", "groups"},
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "sub",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify scopes are preserved
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, found.Config.GetOauth2Config().Scopes, 4)
|
||||
require.Contains(t, found.Config.GetOauth2Config().Scopes, "openid")
|
||||
require.Contains(t, found.Config.GetOauth2Config().Scopes, "groups")
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderFieldMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create IDP with custom field mapping
|
||||
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
|
||||
Name: "Custom Field Mapping",
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
Config: &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
AuthUrl: "https://provider.com/auth",
|
||||
TokenUrl: "https://provider.com/token",
|
||||
UserInfoUrl: "https://provider.com/userinfo",
|
||||
Scopes: []string{"login"},
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "preferred_username",
|
||||
DisplayName: "full_name",
|
||||
Email: "email_address",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify field mapping is preserved
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "preferred_username", found.Config.GetOauth2Config().FieldMapping.Identifier)
|
||||
require.Equal(t, "full_name", found.Config.GetOauth2Config().FieldMapping.DisplayName)
|
||||
require.Equal(t, "email_address", found.Config.GetOauth2Config().FieldMapping.Email)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
filter string
|
||||
}{
|
||||
{"Domain filter", "@company\\.com$"},
|
||||
{"Prefix filter", "^admin_"},
|
||||
{"Complex regex", "^[a-z]+@(dept1|dept2)\\.example\\.com$"},
|
||||
{"Empty filter", ""},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
|
||||
Name: tc.name,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
IdentifierFilter: tc.filter,
|
||||
Config: &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
AuthUrl: "https://provider.com/auth",
|
||||
TokenUrl: "https://provider.com/token",
|
||||
UserInfoUrl: "https://provider.com/userinfo",
|
||||
Scopes: []string{"login"},
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "sub",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &idp.Id})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.filter, found.IdentifierFilter)
|
||||
|
||||
// Cleanup
|
||||
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: idp.Id})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
// Helper function to create a test OAuth2 IDP.
|
||||
func createTestOAuth2IDP(name string) *storepb.IdentityProvider {
|
||||
return &storepb.IdentityProvider{
|
||||
Name: name,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
IdentifierFilter: "",
|
||||
Config: &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
AuthUrl: "https://provider.com/auth",
|
||||
TokenUrl: "https://provider.com/token",
|
||||
UserInfoUrl: "https://provider.com/userinfo",
|
||||
Scopes: []string{"login"},
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "login",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
)
|
||||
|
||||
func TestInboxStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -52,3 +53,540 @@ func TestInboxStore(t *testing.T) {
|
|||
require.Equal(t, 0, len(inboxes))
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListByID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
inbox, err := ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by ID
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ID: &inbox.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, inbox.ID, inboxes[0].ID)
|
||||
|
||||
// List by non-existent ID
|
||||
nonExistentID := int32(99999)
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ID: &nonExistentID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 0)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListBySenderID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user1, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox from system bot (senderID = 0)
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user1.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox from user2
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: user2.ID,
|
||||
ReceiverID: user1.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by sender ID = user2
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{SenderID: &user2.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, user2.ID, inboxes[0].SenderID)
|
||||
|
||||
// List by sender ID = 0 (system bot)
|
||||
systemBotID := int32(0)
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{SenderID: &systemBotID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, int32(0), inboxes[0].SenderID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListByStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create UNREAD inbox
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create another inbox and archive it
|
||||
inbox2, err := ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: inbox2.ID, Status: store.ARCHIVED})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by UNREAD status
|
||||
unreadStatus := store.UNREAD
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{Status: &unreadStatus})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, store.UNREAD, inboxes[0].Status)
|
||||
|
||||
// List by ARCHIVED status
|
||||
archivedStatus := store.ARCHIVED
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{Status: &archivedStatus})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, store.ARCHIVED, inboxes[0].Status)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListByMessageType(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create MEMO_COMMENT inboxes
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by MEMO_COMMENT type
|
||||
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{MessageType: &memoCommentType})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 2)
|
||||
for _, inbox := range inboxes {
|
||||
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
|
||||
}
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListPagination(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create 5 inboxes
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Test Limit only
|
||||
limit := 3
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
Limit: &limit,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 3)
|
||||
|
||||
// Test Limit + Offset (offset requires limit in the implementation)
|
||||
limit = 2
|
||||
offset := 2
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 2)
|
||||
|
||||
// Test Limit + Offset skipping to end
|
||||
limit = 10
|
||||
offset = 3
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 2) // Only 2 remaining after offset of 3
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListCombinedFilters(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user1, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create various inboxes
|
||||
// user2 -> user1, MEMO_COMMENT, UNREAD
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: user2.ID,
|
||||
ReceiverID: user1.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// user2 -> user1, TYPE_UNSPECIFIED, UNREAD
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: user2.ID,
|
||||
ReceiverID: user1.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_TYPE_UNSPECIFIED},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// system -> user1, MEMO_COMMENT, ARCHIVED
|
||||
inbox3, err := ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user1.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: inbox3.ID, Status: store.ARCHIVED})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Combined filter: ReceiverID + SenderID + Status
|
||||
unreadStatus := store.UNREAD
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user1.ID,
|
||||
SenderID: &user2.ID,
|
||||
Status: &unreadStatus,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 2)
|
||||
|
||||
// Combined filter: ReceiverID + MessageType + Status
|
||||
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user1.ID,
|
||||
MessageType: &memoCommentType,
|
||||
Status: &unreadStatus,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, user2.ID, inboxes[0].SenderID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxMessagePayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox with message payload containing activity ID
|
||||
activityID := int32(123)
|
||||
inbox, err := ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
ActivityId: &activityID,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, inbox.Message)
|
||||
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
|
||||
require.Equal(t, activityID, *inbox.Message.ActivityId)
|
||||
|
||||
// List and verify payload is preserved
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, activityID, *inboxes[0].Message.ActivityId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxUpdateStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
inbox, err := ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.UNREAD, inbox.Status)
|
||||
|
||||
// Update to ARCHIVED
|
||||
updated, err := ts.UpdateInbox(ctx, &store.UpdateInbox{
|
||||
ID: inbox.ID,
|
||||
Status: store.ARCHIVED,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, store.ARCHIVED, updated.Status)
|
||||
require.Equal(t, inbox.ID, updated.ID)
|
||||
|
||||
// Verify the update persisted
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ID: &inbox.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, store.ARCHIVED, inboxes[0].Status)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxListByMessageTypeMultipleTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inboxes with different message types
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_TYPE_UNSPECIFIED},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Filter by MEMO_COMMENT - should get 2
|
||||
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
MessageType: &memoCommentType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 2)
|
||||
for _, inbox := range inboxes {
|
||||
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
|
||||
}
|
||||
|
||||
// Filter by TYPE_UNSPECIFIED - should get 1
|
||||
unspecifiedType := storepb.InboxMessage_TYPE_UNSPECIFIED
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
MessageType: &unspecifiedType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, storepb.InboxMessage_TYPE_UNSPECIFIED, inboxes[0].Message.Type)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxMessageTypeFilterWithPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox with full payload
|
||||
activityID := int32(456)
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_MEMO_COMMENT,
|
||||
ActivityId: &activityID,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox with different type but also has payload
|
||||
otherActivityID := int32(789)
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_TYPE_UNSPECIFIED,
|
||||
ActivityId: &otherActivityID,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Filter by type should work correctly even with complex JSON payload
|
||||
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
MessageType: &memoCommentType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, activityID, *inboxes[0].Message.ActivityId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxMessageTypeFilterWithStatusAndPagination(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create multiple inboxes with various combinations
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Archive 2 of them
|
||||
allInboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user.ID})
|
||||
require.NoError(t, err)
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err = ts.UpdateInbox(ctx, &store.UpdateInbox{ID: allInboxes[i].ID, Status: store.ARCHIVED})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Filter by type + status + pagination
|
||||
memoCommentType := storepb.InboxMessage_MEMO_COMMENT
|
||||
unreadStatus := store.UNREAD
|
||||
limit := 2
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
MessageType: &memoCommentType,
|
||||
Status: &unreadStatus,
|
||||
Limit: &limit,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 2)
|
||||
for _, inbox := range inboxes {
|
||||
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inbox.Message.Type)
|
||||
require.Equal(t, store.UNREAD, inbox.Status)
|
||||
}
|
||||
|
||||
// Get next page
|
||||
offset := 2
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
MessageType: &memoCommentType,
|
||||
Status: &unreadStatus,
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1) // Only 1 remaining (3 unread total, got 2, now 1 left)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInboxMultipleReceivers(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user1, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox for user1
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user1.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inbox for user2
|
||||
_, err = ts.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: 0,
|
||||
ReceiverID: user2.ID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{Type: storepb.InboxMessage_MEMO_COMMENT},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// User1 should only see their inbox
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user1.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, user1.ID, inboxes[0].ReceiverID)
|
||||
|
||||
// User2 should only see their inbox
|
||||
inboxes, err = ts.ListInboxes(ctx, &store.FindInbox{ReceiverID: &user2.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.Equal(t, user2.ID, inboxes[0].ReceiverID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
)
|
||||
|
||||
func TestInstanceSettingV1Store(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
instanceSetting, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
|
|
@ -31,6 +32,7 @@ func TestInstanceSettingV1Store(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInstanceSettingGetNonExistent(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -45,6 +47,7 @@ func TestInstanceSettingGetNonExistent(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInstanceSettingUpsertUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -88,6 +91,7 @@ func TestInstanceSettingUpsertUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInstanceSettingBasicSetting(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -116,6 +120,7 @@ func TestInstanceSettingBasicSetting(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInstanceSettingGeneralSetting(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -146,6 +151,7 @@ func TestInstanceSettingGeneralSetting(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInstanceSettingMemoRelatedSetting(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -179,6 +185,7 @@ func TestInstanceSettingMemoRelatedSetting(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInstanceSettingStorageSetting(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -214,6 +221,7 @@ func TestInstanceSettingStorageSetting(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInstanceSettingListAll(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -246,3 +254,47 @@ func TestInstanceSettingListAll(t *testing.T) {
|
|||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInstanceSettingEdgeCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Case 1: General Setting with special characters and Unicode
|
||||
specialScript := `<script>alert("你好"); var x = 'test\'s';</script>`
|
||||
specialStyle := `body { font-family: "Noto Sans SC", sans-serif; content: "\u2764"; }`
|
||||
_, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_GENERAL,
|
||||
Value: &storepb.InstanceSetting_GeneralSetting{
|
||||
GeneralSetting: &storepb.InstanceGeneralSetting{
|
||||
AdditionalScript: specialScript,
|
||||
AdditionalStyle: specialStyle,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
generalSetting, err := ts.GetInstanceGeneralSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, specialScript, generalSetting.AdditionalScript)
|
||||
require.Equal(t, specialStyle, generalSetting.AdditionalStyle)
|
||||
|
||||
// Case 2: Memo Related Setting with Unicode reactions
|
||||
unicodeReactions := []string{"🐱", "🐶", "🦊", "🦄"}
|
||||
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_MEMO_RELATED,
|
||||
Value: &storepb.InstanceSetting_MemoRelatedSetting{
|
||||
MemoRelatedSetting: &storepb.InstanceMemoRelatedSetting{
|
||||
ContentLengthLimit: 1000,
|
||||
Reactions: unicodeReactions,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoSetting, err := ts.GetInstanceMemoRelatedSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, unicodeReactions, memoSetting.Reactions)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ func TestMain(m *testing.M) {
|
|||
// If DRIVER is set, run tests for that driver only
|
||||
if os.Getenv("DRIVER") != "" {
|
||||
defer TerminateContainers()
|
||||
m.Run()
|
||||
m.Run() //nolint:revive // Exit code is handled by test runner
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import (
|
|||
// =============================================================================
|
||||
|
||||
func TestMemoFilterContentContains(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -39,6 +40,7 @@ func TestMemoFilterContentContains(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterContentSpecialCharacters(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -49,6 +51,7 @@ func TestMemoFilterContentSpecialCharacters(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterContentUnicode(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -58,12 +61,35 @@ func TestMemoFilterContentUnicode(t *testing.T) {
|
|||
require.Len(t, memos, 1)
|
||||
}
|
||||
|
||||
func TestMemoFilterContentCaseSensitivity(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-case", tc.User.ID).Content("MixedCase Content"))
|
||||
|
||||
// Exact match
|
||||
memos := tc.ListWithFilter(`content.contains("MixedCase")`)
|
||||
require.Len(t, memos, 1)
|
||||
|
||||
// Lowercase match (depends on DB collation, usually case-insensitive in default installs but good to verify behavior)
|
||||
// SQLite default LIKE is case-insensitive for ASCII.
|
||||
memosLower := tc.ListWithFilter(`content.contains("mixedcase")`)
|
||||
// We just verify it doesn't crash; strict case sensitivity expectation depends on DB config.
|
||||
// For standard Memos setup (SQLite), it's often case-insensitive.
|
||||
// Let's check if we get a result or not to characterize current behavior.
|
||||
if len(memosLower) > 0 {
|
||||
require.Equal(t, "MixedCase Content", memosLower[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Visibility Field Tests
|
||||
// Schema: visibility (string, ==, !=)
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterVisibilityEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -83,6 +109,7 @@ func TestMemoFilterVisibilityEquals(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterVisibilityNotEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -95,6 +122,7 @@ func TestMemoFilterVisibilityNotEquals(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterVisibilityInList(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -112,6 +140,7 @@ func TestMemoFilterVisibilityInList(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestMemoFilterPinnedEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -131,6 +160,7 @@ func TestMemoFilterPinnedEquals(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterPinnedPredicate(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -149,6 +179,7 @@ func TestMemoFilterPinnedPredicate(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestMemoFilterCreatorIdEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -169,6 +200,7 @@ func TestMemoFilterCreatorIdEquals(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterCreatorIdNotEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -195,6 +227,7 @@ func TestMemoFilterCreatorIdNotEquals(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestMemoFilterTagInList(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -213,6 +246,7 @@ func TestMemoFilterTagInList(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterElementInTags(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -229,6 +263,7 @@ func TestMemoFilterElementInTags(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterHierarchicalTags(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -241,6 +276,7 @@ func TestMemoFilterHierarchicalTags(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterEmptyTags(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -257,6 +293,7 @@ func TestMemoFilterEmptyTags(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestMemoFilterHasTaskList(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -279,6 +316,7 @@ func TestMemoFilterHasTaskList(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterHasLink(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -293,6 +331,7 @@ func TestMemoFilterHasLink(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterHasCode(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -307,6 +346,7 @@ func TestMemoFilterHasCode(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterHasIncompleteTasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -329,6 +369,7 @@ func TestMemoFilterHasIncompleteTasks(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterCombinedJSONBool(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -367,6 +408,7 @@ func TestMemoFilterCombinedJSONBool(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestMemoFilterCreatedTsComparison(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -387,6 +429,7 @@ func TestMemoFilterCreatedTsComparison(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterCreatedTsWithNow(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -402,6 +445,7 @@ func TestMemoFilterCreatedTsWithNow(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterCreatedTsArithmetic(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -421,6 +465,7 @@ func TestMemoFilterCreatedTsArithmetic(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterUpdatedTs(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -444,6 +489,7 @@ func TestMemoFilterUpdatedTs(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterAllComparisonOperators(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -472,6 +518,7 @@ func TestMemoFilterAllComparisonOperators(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestMemoFilterLogicalAnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -485,6 +532,7 @@ func TestMemoFilterLogicalAnd(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterLogicalOr(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -497,6 +545,7 @@ func TestMemoFilterLogicalOr(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterLogicalNot(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -510,6 +559,7 @@ func TestMemoFilterLogicalNot(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterNegatedComparison(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -522,6 +572,7 @@ func TestMemoFilterNegatedComparison(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterComplexLogical(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -545,6 +596,244 @@ func TestMemoFilterComplexLogical(t *testing.T) {
|
|||
// Test: (pinned || tag in ["important"]) && visibility == "PUBLIC"
|
||||
memos = tc.ListWithFilter(`(pinned || tag in ["important"]) && visibility == "PUBLIC"`)
|
||||
require.Len(t, memos, 3)
|
||||
|
||||
// Test: De Morgan's Law ! (A || B) == !A && !B
|
||||
// ! (pinned || has_task_list)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-no-props", tc.User.ID).Content("Nothing special"))
|
||||
memos = tc.ListWithFilter(`!(pinned || has_task_list)`)
|
||||
require.Len(t, memos, 2) // Unpinned-tagged + Nothing special (pinned-untagged is pinned)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tag Comprehension Tests (exists macro)
|
||||
// Schema: tags (list of strings, supports exists/all macros with predicates)
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterTagsExistsStartsWith(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create memos with different tags
|
||||
tc.CreateMemo(NewMemoBuilder("memo-archive1", tc.User.ID).
|
||||
Content("Archived project memo").
|
||||
Tags("archive/project", "done"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-archive2", tc.User.ID).
|
||||
Content("Archived work memo").
|
||||
Tags("archive/work", "old"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-active", tc.User.ID).
|
||||
Content("Active project memo").
|
||||
Tags("project/active", "todo"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-homelab", tc.User.ID).
|
||||
Content("Homelab memo").
|
||||
Tags("homelab/memos", "tech"))
|
||||
|
||||
// Test: tags.exists(t, t.startsWith("archive")) - should match archived memos
|
||||
memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("archive"))`)
|
||||
require.Len(t, memos, 2, "Should find 2 archived memos")
|
||||
for _, memo := range memos {
|
||||
hasArchiveTag := false
|
||||
for _, tag := range memo.Payload.Tags {
|
||||
if len(tag) >= 7 && tag[:7] == "archive" {
|
||||
hasArchiveTag = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasArchiveTag, "Memo should have tag starting with 'archive'")
|
||||
}
|
||||
|
||||
// Test: !tags.exists(t, t.startsWith("archive")) - should match non-archived memos
|
||||
memos = tc.ListWithFilter(`!tags.exists(t, t.startsWith("archive"))`)
|
||||
require.Len(t, memos, 2, "Should find 2 non-archived memos")
|
||||
|
||||
// Test: tags.exists(t, t.startsWith("project")) - should match project memos
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("project"))`)
|
||||
require.Len(t, memos, 1, "Should find 1 project memo")
|
||||
|
||||
// Test: tags.exists(t, t.startsWith("homelab")) - should match homelab memos
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("homelab"))`)
|
||||
require.Len(t, memos, 1, "Should find 1 homelab memo")
|
||||
|
||||
// Test: tags.exists(t, t.startsWith("nonexistent")) - should match nothing
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("nonexistent"))`)
|
||||
require.Len(t, memos, 0, "Should find no memos")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsContains(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create memos with different tags
|
||||
tc.CreateMemo(NewMemoBuilder("memo-todo1", tc.User.ID).
|
||||
Content("Todo task 1").
|
||||
Tags("project/todo", "urgent"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-todo2", tc.User.ID).
|
||||
Content("Todo task 2").
|
||||
Tags("work/todo-list", "pending"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-done", tc.User.ID).
|
||||
Content("Done task").
|
||||
Tags("project/completed", "done"))
|
||||
|
||||
// Test: tags.exists(t, t.contains("todo")) - should match todos
|
||||
memos := tc.ListWithFilter(`tags.exists(t, t.contains("todo"))`)
|
||||
require.Len(t, memos, 2, "Should find 2 todo memos")
|
||||
|
||||
// Test: tags.exists(t, t.contains("done")) - should match done
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.contains("done"))`)
|
||||
require.Len(t, memos, 1, "Should find 1 done memo")
|
||||
|
||||
// Test: !tags.exists(t, t.contains("todo")) - should exclude todos
|
||||
memos = tc.ListWithFilter(`!tags.exists(t, t.contains("todo"))`)
|
||||
require.Len(t, memos, 1, "Should find 1 non-todo memo")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsEndsWith(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create memos with different tag endings
|
||||
tc.CreateMemo(NewMemoBuilder("memo-bug", tc.User.ID).
|
||||
Content("Bug report").
|
||||
Tags("project/bug", "critical"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-debug", tc.User.ID).
|
||||
Content("Debug session").
|
||||
Tags("work/debug", "dev"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-feature", tc.User.ID).
|
||||
Content("New feature").
|
||||
Tags("project/feature", "new"))
|
||||
|
||||
// Test: tags.exists(t, t.endsWith("bug")) - should match bug-related tags
|
||||
memos := tc.ListWithFilter(`tags.exists(t, t.endsWith("bug"))`)
|
||||
require.Len(t, memos, 2, "Should find 2 bug-related memos")
|
||||
|
||||
// Test: tags.exists(t, t.endsWith("feature")) - should match feature
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.endsWith("feature"))`)
|
||||
require.Len(t, memos, 1, "Should find 1 feature memo")
|
||||
|
||||
// Test: !tags.exists(t, t.endsWith("bug")) - should exclude bug-related
|
||||
memos = tc.ListWithFilter(`!tags.exists(t, t.endsWith("bug"))`)
|
||||
require.Len(t, memos, 1, "Should find 1 non-bug memo")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsCombinedWithOtherFilters(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create memos with tags and other properties
|
||||
tc.CreateMemo(NewMemoBuilder("memo-archived-old", tc.User.ID).
|
||||
Content("Old archived memo").
|
||||
Tags("archive/old", "done"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-archived-recent", tc.User.ID).
|
||||
Content("Recent archived memo with TODO").
|
||||
Tags("archive/recent", "done"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-active-todo", tc.User.ID).
|
||||
Content("Active TODO").
|
||||
Tags("project/active", "todo"))
|
||||
|
||||
// Test: Combine tag filter with content filter
|
||||
memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) && content.contains("TODO")`)
|
||||
require.Len(t, memos, 1, "Should find 1 archived memo with TODO in content")
|
||||
|
||||
// Test: OR condition with tag filters
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) || tags.exists(t, t.contains("todo"))`)
|
||||
require.Len(t, memos, 3, "Should find all memos (archived or with todo tag)")
|
||||
|
||||
// Test: Complex filter - archived but not containing "Recent"
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive")) && !content.contains("Recent")`)
|
||||
require.Len(t, memos, 1, "Should find 1 old archived memo")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsEmptyAndNullCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create memo with no tags
|
||||
tc.CreateMemo(NewMemoBuilder("memo-no-tags", tc.User.ID).
|
||||
Content("Memo without tags"))
|
||||
|
||||
// Create memo with tags
|
||||
tc.CreateMemo(NewMemoBuilder("memo-with-tags", tc.User.ID).
|
||||
Content("Memo with tags").
|
||||
Tags("tag1", "tag2"))
|
||||
|
||||
// Test: tags.exists should not match memos without tags
|
||||
memos := tc.ListWithFilter(`tags.exists(t, t.startsWith("tag"))`)
|
||||
require.Len(t, memos, 1, "Should only find memo with tags")
|
||||
|
||||
// Test: Negation should match memos without matching tags
|
||||
memos = tc.ListWithFilter(`!tags.exists(t, t.startsWith("tag"))`)
|
||||
require.Len(t, memos, 1, "Should find memo without matching tags")
|
||||
}
|
||||
|
||||
func TestMemoFilterIssue5480_ArchiveWorkflow(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// Create a realistic scenario as described in issue #5480
|
||||
// User has hierarchical tags and archives memos by prefixing with "archive"
|
||||
|
||||
// Active memos
|
||||
tc.CreateMemo(NewMemoBuilder("memo-homelab", tc.User.ID).
|
||||
Content("Setting up Memos").
|
||||
Tags("homelab/memos", "tech"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-project-alpha", tc.User.ID).
|
||||
Content("Project Alpha notes").
|
||||
Tags("work/project-alpha", "active"))
|
||||
|
||||
// Archived memos (user prefixed tags with "archive")
|
||||
tc.CreateMemo(NewMemoBuilder("memo-old-homelab", tc.User.ID).
|
||||
Content("Old homelab setup").
|
||||
Tags("archive/homelab/old-server", "done"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-old-project", tc.User.ID).
|
||||
Content("Old project beta").
|
||||
Tags("archive/work/project-beta", "completed"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-archived-personal", tc.User.ID).
|
||||
Content("Archived personal note").
|
||||
Tags("archive/personal/2024", "old"))
|
||||
|
||||
// Test: Filter out ALL archived memos using startsWith
|
||||
memos := tc.ListWithFilter(`!tags.exists(t, t.startsWith("archive"))`)
|
||||
require.Len(t, memos, 2, "Should only show active memos (not archived)")
|
||||
for _, memo := range memos {
|
||||
for _, tag := range memo.Payload.Tags {
|
||||
require.NotContains(t, tag, "archive", "Active memos should not have archive prefix")
|
||||
}
|
||||
}
|
||||
|
||||
// Test: Show ONLY archived memos
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive"))`)
|
||||
require.Len(t, memos, 3, "Should find all archived memos")
|
||||
for _, memo := range memos {
|
||||
hasArchiveTag := false
|
||||
for _, tag := range memo.Payload.Tags {
|
||||
if len(tag) >= 7 && tag[:7] == "archive" {
|
||||
hasArchiveTag = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasArchiveTag, "All returned memos should have archive prefix")
|
||||
}
|
||||
|
||||
// Test: Filter archived homelab memos specifically
|
||||
memos = tc.ListWithFilter(`tags.exists(t, t.startsWith("archive/homelab"))`)
|
||||
require.Len(t, memos, 1, "Should find only archived homelab memos")
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
|
@ -552,6 +841,7 @@ func TestMemoFilterComplexLogical(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestMemoFilterMultipleFilters(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -570,6 +860,7 @@ func TestMemoFilterMultipleFilters(t *testing.T) {
|
|||
// =============================================================================
|
||||
|
||||
func TestMemoFilterNullPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -581,6 +872,7 @@ func TestMemoFilterNullPayload(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoFilterNoMatches(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
|
|
@ -589,3 +881,48 @@ func TestMemoFilterNoMatches(t *testing.T) {
|
|||
memos := tc.ListWithFilter(`content.contains("nonexistent12345")`)
|
||||
require.Len(t, memos, 0)
|
||||
}
|
||||
|
||||
func TestMemoFilterJSONBooleanLogic(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
// 1. Memo with task list (true) and NO link (null)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-task-only", tc.User.ID).
|
||||
Content("Task only").
|
||||
Property(func(p *storepb.MemoPayload_Property) { p.HasTaskList = true }))
|
||||
|
||||
// 2. Memo with link (true) and NO task list (null)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-link-only", tc.User.ID).
|
||||
Content("Link only").
|
||||
Property(func(p *storepb.MemoPayload_Property) { p.HasLink = true }))
|
||||
|
||||
// 3. Memo with both (true)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-both", tc.User.ID).
|
||||
Content("Both").
|
||||
Property(func(p *storepb.MemoPayload_Property) {
|
||||
p.HasTaskList = true
|
||||
p.HasLink = true
|
||||
}))
|
||||
|
||||
// 4. Memo with neither (null)
|
||||
tc.CreateMemo(NewMemoBuilder("memo-neither", tc.User.ID).Content("Neither"))
|
||||
|
||||
// Test A: has_task_list || has_link
|
||||
// Expected: 3 memos (task-only, link-only, both). Neither should be excluded.
|
||||
// This specifically tests the NULL handling in OR logic (NULL || TRUE should be TRUE)
|
||||
memos := tc.ListWithFilter(`has_task_list || has_link`)
|
||||
require.Len(t, memos, 3, "Should find 3 memos with OR logic")
|
||||
|
||||
// Test B: !has_task_list
|
||||
// Expected: 2 memos (link-only, neither). Memos where has_task_list is NULL or FALSE.
|
||||
// Note: If NULL is not handled, !NULL is still NULL (false-y in WHERE), so "neither" might be missed depending on logic.
|
||||
// In our implementation, we want missing fields to behave as false.
|
||||
memos = tc.ListWithFilter(`!has_task_list`)
|
||||
require.Len(t, memos, 2, "Should find 2 memos where task list is false or missing")
|
||||
|
||||
// Test C: has_task_list && !has_link
|
||||
// Expected: 1 memo (task-only).
|
||||
memos = tc.ListWithFilter(`has_task_list && !has_link`)
|
||||
require.Len(t, memos, 1, "Should find 1 memo (task only)")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
)
|
||||
|
||||
func TestMemoRelationStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -62,6 +63,7 @@ func TestMemoRelationStore(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoRelationListByMemoID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -136,6 +138,7 @@ func TestMemoRelationListByMemoID(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoRelationDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -193,6 +196,7 @@ func TestMemoRelationDelete(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoRelationDifferentTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -239,3 +243,440 @@ func TestMemoRelationDifferentTypes(t *testing.T) {
|
|||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationUpsertSameRelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "main-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "main memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create relation
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Upsert the same relation again (should not create duplicate)
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only one relation exists
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationDeleteByType(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "main-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "main memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo1, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo-1",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo 1 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo2, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo-2",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo 2 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create reference relations
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo1.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create comment relation
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo2.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete only reference type relations
|
||||
refType := store.MemoRelationReference
|
||||
err = ts.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
Type: &refType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only comment relation remains
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
require.Equal(t, store.MemoRelationComment, relations[0].Type)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationDeleteByMemoID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
memo1, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-1",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo 1 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memo2, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-2",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo 2 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create relations for both memos
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memo1.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memo2.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete all relations for memo1
|
||||
err = ts.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||
MemoID: &memo1.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify memo1's relations are gone
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &memo1.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 0)
|
||||
|
||||
// Verify memo2's relations still exist
|
||||
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &memo2.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationListByRelatedMemoID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a memo that will be referenced by others
|
||||
targetMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "target-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "target memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create memos that reference the target
|
||||
referrer1, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "referrer-1",
|
||||
CreatorID: user.ID,
|
||||
Content: "referrer 1 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
referrer2, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "referrer-2",
|
||||
CreatorID: user.ID,
|
||||
Content: "referrer 2 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create relations pointing to target
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: referrer1.ID,
|
||||
RelatedMemoID: targetMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: referrer2.ID,
|
||||
RelatedMemoID: targetMemo.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by related memo ID (find all memos that reference the target)
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
RelatedMemoID: &targetMemo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 2)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationListCombinedFilters(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "main-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "main memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo1, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo-1",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo 1 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
relatedMemo2, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo-2",
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo 2 content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create multiple relations
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo1.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo2.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List with MemoID and Type filter
|
||||
refType := store.MemoRelationReference
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
Type: &refType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
require.Equal(t, relatedMemo1.ID, relations[0].RelatedMemoID)
|
||||
|
||||
// List with MemoID, RelatedMemoID, and Type filter
|
||||
commentType := store.MemoRelationComment
|
||||
relations, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
RelatedMemoID: &relatedMemo2.ID,
|
||||
Type: &commentType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 1)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationListEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-no-relations",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo with no relations",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List relations for memo with none
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 0)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationBidirectional(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
memoA, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-a",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo A content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoB, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "memo-b",
|
||||
CreatorID: user.ID,
|
||||
Content: "memo B content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create relation A -> B
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memoA.ID,
|
||||
RelatedMemoID: memoB.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create relation B -> A (reverse direction)
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memoB.ID,
|
||||
RelatedMemoID: memoA.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify A -> B exists
|
||||
relationsFromA, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &memoA.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relationsFromA, 1)
|
||||
require.Equal(t, memoB.ID, relationsFromA[0].RelatedMemoID)
|
||||
|
||||
// Verify B -> A exists
|
||||
relationsFromB, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &memoB.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relationsFromB, 1)
|
||||
require.Equal(t, memoA.ID, relationsFromB[0].RelatedMemoID)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestMemoRelationMultipleRelationsToSameMemo(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
mainMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "main-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "main memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create multiple memos that all relate to the main memo
|
||||
for i := 1; i <= 5; i++ {
|
||||
relatedMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "related-memo-" + string(rune('0'+i)),
|
||||
CreatorID: user.ID,
|
||||
Content: "related memo content",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: mainMemo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify all 5 relations exist
|
||||
relations, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &mainMemo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, relations, 5)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import (
|
|||
)
|
||||
|
||||
func TestMemoStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -64,6 +65,7 @@ func TestMemoStore(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoListByTags(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -96,6 +98,7 @@ func TestMemoListByTags(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDeleteMemoStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -117,6 +120,7 @@ func TestDeleteMemoStore(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoGetByID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -147,6 +151,7 @@ func TestMemoGetByID(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoGetByUID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -177,6 +182,7 @@ func TestMemoGetByUID(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoListByVisibility(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -239,6 +245,7 @@ func TestMemoListByVisibility(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoListWithPagination(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -274,6 +281,7 @@ func TestMemoListWithPagination(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoUpdatePinned(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -317,6 +325,7 @@ func TestMemoUpdatePinned(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoUpdateVisibility(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -359,6 +368,7 @@ func TestMemoUpdateVisibility(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoInvalidUID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -378,6 +388,7 @@ func TestMemoInvalidUID(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMemoWithPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
|
|||
|
|
@ -2,16 +2,199 @@ package test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestGetCurrentSchemaVersion(t *testing.T) {
|
||||
// TestFreshInstall verifies that LATEST.sql applies correctly on a fresh database.
|
||||
// This is essentially what NewTestingStore already does, but we make it explicit.
|
||||
func TestFreshInstall(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
// NewTestingStore creates a fresh database and runs Migrate()
|
||||
// which applies LATEST.sql for uninitialized databases
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Verify migration completed successfully
|
||||
currentSchemaVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, currentSchemaVersion, "schema version should be set after fresh install")
|
||||
|
||||
// Verify we can read instance settings (basic sanity check)
|
||||
instanceSetting, err := ts.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, currentSchemaVersion, instanceSetting.SchemaVersion)
|
||||
}
|
||||
|
||||
// TestMigrationReRun verifies that re-running the migration on an already
|
||||
// migrated database does not fail or cause issues. This simulates a
|
||||
// scenario where the server is restarted.
|
||||
func TestMigrationReRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
// Use the shared testing store which already runs migrations on init
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Get current version
|
||||
initialVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually trigger migration again
|
||||
err = ts.Migrate(ctx)
|
||||
require.NoError(t, err, "re-running migration should not fail")
|
||||
|
||||
// Verify version hasn't changed (or at least is valid)
|
||||
finalVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, initialVersion, finalVersion, "version should match after re-run")
|
||||
}
|
||||
|
||||
// TestMigrationWithData verifies that migration preserves data integrity.
|
||||
// Creates data, then re-runs migration and verifies data is still accessible.
|
||||
func TestMigrationWithData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
currentSchemaVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "0.25.1", currentSchemaVersion)
|
||||
// Create a user and memo before re-running migration
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err, "should create user")
|
||||
|
||||
originalMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "migration-data-test",
|
||||
CreatorID: user.ID,
|
||||
Content: "Data before migration re-run",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err, "should create memo")
|
||||
|
||||
// Re-run migration
|
||||
err = ts.Migrate(ctx)
|
||||
require.NoError(t, err, "re-running migration should not fail")
|
||||
|
||||
// Verify data is still accessible
|
||||
memo, err := ts.GetMemo(ctx, &store.FindMemo{UID: &originalMemo.UID})
|
||||
require.NoError(t, err, "should retrieve memo after migration")
|
||||
require.Equal(t, "Data before migration re-run", memo.Content, "memo content should be preserved")
|
||||
}
|
||||
|
||||
// TestMigrationMultipleReRuns verifies that migration is idempotent
|
||||
// even when run multiple times in succession.
|
||||
func TestMigrationMultipleReRuns(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Get initial version
|
||||
initialVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Run migration multiple times
|
||||
for i := 0; i < 3; i++ {
|
||||
err = ts.Migrate(ctx)
|
||||
require.NoError(t, err, "migration run %d should not fail", i+1)
|
||||
}
|
||||
|
||||
// Verify version is still correct
|
||||
finalVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, initialVersion, finalVersion, "version should remain unchanged after multiple re-runs")
|
||||
}
|
||||
|
||||
// TestMigrationFromStableVersion verifies that upgrading from a stable Memos version
|
||||
// to the current version works correctly. This is the critical upgrade path test.
|
||||
//
|
||||
// Test flow:
|
||||
// 1. Start a stable Memos container to create a database with the old schema
|
||||
// 2. Stop the container and wait for cleanup
|
||||
// 3. Use the store directly to run migration with current code
|
||||
// 4. Verify the migration succeeded and data can be written
|
||||
//
|
||||
// Note: This test is skipped when running with -race flag because testcontainers
|
||||
// has known race conditions in its reaper code that are outside our control.
|
||||
func TestMigrationFromStableVersion(t *testing.T) {
|
||||
// Skip for non-SQLite drivers (simplifies the test)
|
||||
if getDriverFromEnv() != "sqlite" {
|
||||
t.Skip("skipping upgrade test for non-sqlite driver")
|
||||
}
|
||||
|
||||
// Skip if explicitly disabled (e.g., in environments without Docker)
|
||||
if os.Getenv("SKIP_CONTAINER_TESTS") == "1" {
|
||||
t.Skip("skipping container-based test (SKIP_CONTAINER_TESTS=1)")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
dataDir := t.TempDir()
|
||||
|
||||
// 1. Start stable Memos container to create database with old schema
|
||||
cfg := MemosContainerConfig{
|
||||
Driver: "sqlite",
|
||||
DataDir: dataDir,
|
||||
Version: StableMemosVersion,
|
||||
}
|
||||
|
||||
t.Logf("Starting Memos %s container to create old-schema database...", cfg.Version)
|
||||
container, err := StartMemosContainer(ctx, cfg)
|
||||
require.NoError(t, err, "failed to start stable memos container")
|
||||
|
||||
// Wait for the container to fully initialize the database
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
// Stop the container gracefully
|
||||
t.Log("Stopping stable Memos container...")
|
||||
err = container.Terminate(ctx)
|
||||
require.NoError(t, err, "failed to stop memos container")
|
||||
|
||||
// Wait for file handles to be released
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 2. Connect to the database directly and run migration with current code
|
||||
dsn := fmt.Sprintf("%s/memos_prod.db", dataDir)
|
||||
t.Logf("Connecting to database at %s...", dsn)
|
||||
|
||||
ts := NewTestingStoreWithDSN(ctx, t, "sqlite", dsn)
|
||||
|
||||
// Get the schema version before migration
|
||||
oldSetting, err := ts.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Old schema version: %s", oldSetting.SchemaVersion)
|
||||
|
||||
// 3. Run migration with current code
|
||||
t.Log("Running migration with current code...")
|
||||
err = ts.Migrate(ctx)
|
||||
require.NoError(t, err, "migration from stable version should succeed")
|
||||
|
||||
// 4. Verify migration succeeded
|
||||
newVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
t.Logf("New schema version: %s", newVersion)
|
||||
|
||||
newSetting, err := ts.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newVersion, newSetting.SchemaVersion, "schema version should be updated")
|
||||
|
||||
// Verify we can write data to the migrated database
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err, "should create user after migration")
|
||||
|
||||
memo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "post-upgrade-memo",
|
||||
CreatorID: user.ID,
|
||||
Content: "Content after upgrade from stable",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err, "should create memo after migration")
|
||||
require.Equal(t, "Content after upgrade from stable", memo.Content)
|
||||
|
||||
t.Logf("Migration successful: %s -> %s", oldSetting.SchemaVersion, newVersion)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
)
|
||||
|
||||
func TestReactionStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -67,6 +68,7 @@ func TestReactionStore(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestReactionListByCreatorID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -113,6 +115,7 @@ func TestReactionListByCreatorID(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestReactionMultipleContentIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
@ -148,6 +151,7 @@ func TestReactionMultipleContentIDs(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestReactionUpsertDifferentTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,27 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
|
|||
return store
|
||||
}
|
||||
|
||||
// NewTestingStoreWithDSN creates a testing store connected to a specific DSN.
|
||||
// This is useful for testing migrations on existing data.
|
||||
func NewTestingStoreWithDSN(_ context.Context, t *testing.T, driver, dsn string) *store.Store {
|
||||
profile := &profile.Profile{
|
||||
Port: getUnusedPort(),
|
||||
Data: t.TempDir(), // Dummy dir, DSN matters
|
||||
DSN: dsn,
|
||||
Driver: driver,
|
||||
Version: version.GetCurrentVersion(),
|
||||
}
|
||||
dbDriver, err := db.NewDBDriver(profile)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create db driver: %v", err)
|
||||
}
|
||||
|
||||
store := store.New(dbDriver, profile)
|
||||
// Do not run Migrate() automatically, as we might be testing pre-migration state
|
||||
// or want to run it manually.
|
||||
return store
|
||||
}
|
||||
|
||||
func getUnusedPort() int {
|
||||
// Get a random unused port
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
|
|
@ -73,12 +94,11 @@ func getTestingProfileForDriver(t *testing.T, driver string) *profile.Profile {
|
|||
}
|
||||
|
||||
return &profile.Profile{
|
||||
Mode: mode,
|
||||
Port: port,
|
||||
Data: dir,
|
||||
DSN: dsn,
|
||||
Driver: driver,
|
||||
Version: version.GetCurrentVersion(mode),
|
||||
Version: version.GetCurrentVersion(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,15 +2,18 @@ package test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestUserSettingStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -28,6 +31,7 @@ func TestUserSettingStore(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUserSettingGetByUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -62,6 +66,7 @@ func TestUserSettingGetByUserID(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUserSettingUpsertUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -100,6 +105,7 @@ func TestUserSettingUpsertUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUserSettingRefreshTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -160,6 +166,7 @@ func TestUserSettingRefreshTokens(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUserSettingPersonalAccessTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -211,6 +218,7 @@ func TestUserSettingPersonalAccessTokens(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUserSettingWebhooks(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -277,6 +285,7 @@ func TestUserSettingWebhooks(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUserSettingShortcuts(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
|
|
@ -308,3 +317,559 @@ func TestUserSettingShortcuts(t *testing.T) {
|
|||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a PAT with a known hash
|
||||
patHash := "test-pat-hash-12345"
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-test-1",
|
||||
TokenHash: patHash,
|
||||
Description: "Test PAT for lookup",
|
||||
}
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Lookup user by PAT hash
|
||||
result, err := ts.GetUserByPATHash(ctx, patHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, user.ID, result.UserID)
|
||||
require.NotNil(t, result.User)
|
||||
require.Equal(t, user.Username, result.User.Username)
|
||||
require.NotNil(t, result.PAT)
|
||||
require.Equal(t, "pat-test-1", result.PAT.TokenId)
|
||||
require.Equal(t, "Test PAT for lookup", result.PAT.Description)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
_, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Lookup non-existent PAT hash
|
||||
result, err := ts.GetUserByPATHash(ctx, "non-existent-hash")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashMultipleUsers(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user1, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
user2, err := createTestingUserWithRole(ctx, ts, "user2", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create PATs for both users
|
||||
pat1Hash := "user1-pat-hash"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user1.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-user1",
|
||||
TokenHash: pat1Hash,
|
||||
Description: "User 1 PAT",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pat2Hash := "user2-pat-hash"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user2.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-user2",
|
||||
TokenHash: pat2Hash,
|
||||
Description: "User 2 PAT",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Lookup user1's PAT
|
||||
result1, err := ts.GetUserByPATHash(ctx, pat1Hash)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user1.ID, result1.UserID)
|
||||
require.Equal(t, user1.Username, result1.User.Username)
|
||||
|
||||
// Lookup user2's PAT
|
||||
result2, err := ts.GetUserByPATHash(ctx, pat2Hash)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user2.ID, result2.UserID)
|
||||
require.Equal(t, user2.Username, result2.User.Username)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashMultiplePATsSameUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create multiple PATs for the same user
|
||||
pat1Hash := "first-pat-hash"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-1",
|
||||
TokenHash: pat1Hash,
|
||||
Description: "First PAT",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pat2Hash := "second-pat-hash"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-2",
|
||||
TokenHash: pat2Hash,
|
||||
Description: "Second PAT",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both PATs should resolve to the same user
|
||||
result1, err := ts.GetUserByPATHash(ctx, pat1Hash)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user.ID, result1.UserID)
|
||||
require.Equal(t, "pat-1", result1.PAT.TokenId)
|
||||
|
||||
result2, err := ts.GetUserByPATHash(ctx, pat2Hash)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user.ID, result2.UserID)
|
||||
require.Equal(t, "pat-2", result2.PAT.TokenId)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingUpdatePATLastUsed(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a PAT
|
||||
patHash := "pat-hash-for-update"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-update-test",
|
||||
TokenHash: patHash,
|
||||
Description: "PAT for update test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update last used timestamp
|
||||
now := timestamppb.Now()
|
||||
err = ts.UpdatePATLastUsed(ctx, user.ID, "pat-update-test", now)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the update
|
||||
pats, err := ts.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pats, 1)
|
||||
require.NotNil(t, pats[0].LastUsedAt)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashWithExpiredToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a PAT with expiration info
|
||||
patHash := "pat-hash-with-expiry"
|
||||
expiresAt := timestamppb.Now()
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-expiry-test",
|
||||
TokenHash: patHash,
|
||||
Description: "PAT with expiry",
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should still be able to look up by hash (expiry check is done at auth level)
|
||||
result, err := ts.GetUserByPATHash(ctx, patHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, user.ID, result.UserID)
|
||||
require.NotNil(t, result.PAT.ExpiresAt)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashAfterRemoval(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a PAT
|
||||
patHash := "pat-hash-to-remove"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-remove-test",
|
||||
TokenHash: patHash,
|
||||
Description: "PAT to be removed",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it exists
|
||||
result, err := ts.GetUserByPATHash(ctx, patHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Remove the PAT
|
||||
err = ts.RemoveUserPersonalAccessToken(ctx, user.ID, "pat-remove-test")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should no longer be found
|
||||
result, err = ts.GetUserByPATHash(ctx, patHash)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashSpecialCharacters(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create PATs with special characters in hash (simulating real hash values)
|
||||
testCases := []struct {
|
||||
tokenID string
|
||||
tokenHash string
|
||||
}{
|
||||
{"pat-special-1", "abc123+/=XYZ"},
|
||||
{"pat-special-2", "sha256:abcdef1234567890"},
|
||||
{"pat-special-3", "$2a$10$N9qo8uLOickgx2ZMRZoMy"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tc.tokenID,
|
||||
TokenHash: tc.tokenHash,
|
||||
Description: "PAT with special chars",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify lookup works with special characters
|
||||
result, err := ts.GetUserByPATHash(ctx, tc.tokenHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tc.tokenID, result.PAT.TokenId)
|
||||
}
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingGetUserByPATHashLargeTokenCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create many PATs for the same user
|
||||
tokenCount := 10
|
||||
hashes := make([]string, tokenCount)
|
||||
for i := 0; i < tokenCount; i++ {
|
||||
hashes[i] = "pat-hash-" + string(rune('A'+i)) + "-large-test"
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-large-" + string(rune('A'+i)),
|
||||
TokenHash: hashes[i],
|
||||
Description: "PAT for large count test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify each hash can be looked up
|
||||
for i, hash := range hashes {
|
||||
result, err := ts.GetUserByPATHash(ctx, hash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, user.ID, result.UserID)
|
||||
require.Equal(t, "pat-large-"+string(rune('A'+i)), result.PAT.TokenId)
|
||||
}
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingMultipleSettingTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create GENERAL setting
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_GENERAL,
|
||||
Value: &storepb.UserSetting_General{General: &storepb.GeneralUserSetting{Locale: "ja"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create SHORTCUTS setting
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
|
||||
{Id: "s1", Title: "Shortcut 1"},
|
||||
},
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a PAT
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-multi",
|
||||
TokenHash: "hash-multi",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// List all settings for user
|
||||
settings, err := ts.ListUserSettings(ctx, &store.FindUserSetting{UserID: &user.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, settings, 3)
|
||||
|
||||
// Verify each setting type
|
||||
generalSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_GENERAL})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "ja", generalSetting.GetGeneral().Locale)
|
||||
|
||||
shortcutsSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_SHORTCUTS})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, shortcutsSetting.GetShortcuts().Shortcuts, 1)
|
||||
|
||||
patsSetting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{UserID: &user.ID, Key: storepb.UserSetting_PERSONAL_ACCESS_TOKENS})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, patsSetting.GetPersonalAccessTokens().Tokens, 1)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingShortcutsEdgeCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Case 1: Special characters in Filter and Title
|
||||
// Includes quotes, backslashes, newlines, and other JSON-sensitive characters
|
||||
specialCharsFilter := `tag in ["work", "project"] && content.contains("urgent")`
|
||||
specialCharsTitle := `Work "Urgent" \ Notes`
|
||||
shortcuts := &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
|
||||
{Id: "s1", Title: specialCharsTitle, Filter: specialCharsFilter},
|
||||
},
|
||||
}
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, setting)
|
||||
require.Len(t, setting.GetShortcuts().Shortcuts, 1)
|
||||
require.Equal(t, specialCharsTitle, setting.GetShortcuts().Shortcuts[0].Title)
|
||||
require.Equal(t, specialCharsFilter, setting.GetShortcuts().Shortcuts[0].Filter)
|
||||
|
||||
// Case 2: Unicode characters
|
||||
unicodeFilter := `tag in ["你好", "世界"]`
|
||||
unicodeTitle := `My 🚀 Shortcuts`
|
||||
shortcuts = &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
|
||||
{Id: "s2", Title: unicodeTitle, Filter: unicodeFilter},
|
||||
},
|
||||
}
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, setting)
|
||||
require.Len(t, setting.GetShortcuts().Shortcuts, 1)
|
||||
require.Equal(t, unicodeTitle, setting.GetShortcuts().Shortcuts[0].Title)
|
||||
require.Equal(t, unicodeFilter, setting.GetShortcuts().Shortcuts[0].Filter)
|
||||
|
||||
// Case 3: Empty shortcuts list
|
||||
// Should allow saving an empty list (clearing shortcuts)
|
||||
shortcuts = &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{},
|
||||
}
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, setting)
|
||||
require.NotNil(t, setting.GetShortcuts())
|
||||
require.Len(t, setting.GetShortcuts().Shortcuts, 0)
|
||||
|
||||
// Case 4: Large filter string
|
||||
// Test reasonable large string handling (e.g. 4KB)
|
||||
largeFilter := strings.Repeat("tag:long_tag_name ", 200)
|
||||
shortcuts = &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
|
||||
{Id: "s3", Title: "Large Filter", Filter: largeFilter},
|
||||
},
|
||||
}
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
setting, err = ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, setting)
|
||||
require.Equal(t, largeFilter, setting.GetShortcuts().Shortcuts[0].Filter)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingShortcutsPartialUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initial set
|
||||
shortcuts := &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
|
||||
{Id: "s1", Title: "Note 1", Filter: "tag:1"},
|
||||
{Id: "s2", Title: "Note 2", Filter: "tag:2"},
|
||||
},
|
||||
}
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: shortcuts},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update by replacing the whole list (Store Upsert replaces the value for the key)
|
||||
// We want to verify that we can "update" a single item by sending the modified list
|
||||
updatedShortcuts := &storepb.ShortcutsUserSetting{
|
||||
Shortcuts: []*storepb.ShortcutsUserSetting_Shortcut{
|
||||
{Id: "s1", Title: "Note 1 Updated", Filter: "tag:1_updated"},
|
||||
{Id: "s2", Title: "Note 2", Filter: "tag:2"},
|
||||
{Id: "s3", Title: "Note 3", Filter: "tag:3"}, // Add new one
|
||||
},
|
||||
}
|
||||
_, err = ts.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
Value: &storepb.UserSetting_Shortcuts{Shortcuts: updatedShortcuts},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
setting, err := ts.GetUserSetting(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
Key: storepb.UserSetting_SHORTCUTS,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, setting)
|
||||
require.Len(t, setting.GetShortcuts().Shortcuts, 3)
|
||||
|
||||
// Verify updates
|
||||
for _, s := range setting.GetShortcuts().Shortcuts {
|
||||
if s.Id == "s1" {
|
||||
require.Equal(t, "Note 1 Updated", s.Title)
|
||||
require.Equal(t, "tag:1_updated", s.Filter)
|
||||
} else if s.Id == "s2" {
|
||||
require.Equal(t, "Note 2", s.Title)
|
||||
} else if s.Id == "s3" {
|
||||
require.Equal(t, "Note 3", s.Title)
|
||||
}
|
||||
}
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestUserSettingJSONFieldsEdgeCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Case 1: Webhook with special characters and Unicode in Title and URL
|
||||
specialWebhook := &storepb.WebhooksUserSetting_Webhook{
|
||||
Id: "wh-special",
|
||||
Title: `My "Special" & <Webhook> 🚀`,
|
||||
Url: "https://example.com/hook?query=你好¶m=\"value\"",
|
||||
}
|
||||
err = ts.AddUserWebhook(ctx, user.ID, specialWebhook)
|
||||
require.NoError(t, err)
|
||||
|
||||
webhooks, err := ts.GetUserWebhooks(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, webhooks, 1)
|
||||
require.Equal(t, specialWebhook.Title, webhooks[0].Title)
|
||||
require.Equal(t, specialWebhook.Url, webhooks[0].Url)
|
||||
|
||||
// Case 2: PAT with special description
|
||||
specialPAT := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: "pat-special",
|
||||
TokenHash: "hash-special",
|
||||
Description: "Token for 'CLI' \n & \"API\" \t with unicode 🔑",
|
||||
}
|
||||
err = ts.AddUserPersonalAccessToken(ctx, user.ID, specialPAT)
|
||||
require.NoError(t, err)
|
||||
|
||||
pats, err := ts.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pats, 1)
|
||||
require.Equal(t, specialPAT.Description, pats[0].Description)
|
||||
|
||||
// Case 3: Refresh Token with special description
|
||||
specialRefreshToken := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: "rt-special",
|
||||
Description: "Browser: Firefox (Nightly) / OS: Linux 🐧",
|
||||
}
|
||||
err = ts.AddUserRefreshToken(ctx, user.ID, specialRefreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
tokens, err := ts.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tokens, 1)
|
||||
require.Equal(t, specialRefreshToken.Description, tokens[0].Description)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue