mirror of https://github.com/usememos/memos.git
Refactor code structure and remove redundant changes
This commit is contained in:
commit
446642b2dd
|
|
@ -24,6 +24,8 @@ jobs:
|
|||
outputs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
tag: ${{ steps.version.outputs.tag }}
|
||||
major_minor: ${{ steps.version.outputs.major_minor }}
|
||||
is_prerelease: ${{ steps.version.outputs.is_prerelease }}
|
||||
steps:
|
||||
- name: Extract version
|
||||
id: version
|
||||
|
|
@ -34,11 +36,27 @@ jobs:
|
|||
if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
|
||||
echo "tag=" >> "$GITHUB_OUTPUT"
|
||||
echo "version=manual-${GITHUB_SHA::7}" >> "$GITHUB_OUTPUT"
|
||||
echo "major_minor=" >> "$GITHUB_OUTPUT"
|
||||
echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ ! "$REF_NAME" =~ ^v([0-9]+\.[0-9]+\.[0-9]+)(-rc\.[0-9]+)?$ ]]; then
|
||||
echo "Unsupported release tag format: $REF_NAME" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
version="${BASH_REMATCH[1]}${BASH_REMATCH[2]}"
|
||||
major_minor="${BASH_REMATCH[1]%.*}"
|
||||
is_prerelease=false
|
||||
if [ -n "${BASH_REMATCH[2]}" ]; then
|
||||
is_prerelease=true
|
||||
fi
|
||||
|
||||
echo "tag=${REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||
echo "version=${REF_NAME#v}" >> "$GITHUB_OUTPUT"
|
||||
echo "version=${version}" >> "$GITHUB_OUTPUT"
|
||||
echo "major_minor=${major_minor}" >> "$GITHUB_OUTPUT"
|
||||
echo "is_prerelease=${is_prerelease}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
build-frontend:
|
||||
name: Build Frontend
|
||||
|
|
@ -226,6 +244,7 @@ jobs:
|
|||
tag_name: ${{ needs.prepare.outputs.tag }}
|
||||
name: ${{ needs.prepare.outputs.tag }}
|
||||
generate_release_notes: true
|
||||
prerelease: ${{ needs.prepare.outputs.is_prerelease == 'true' }}
|
||||
files: artifacts/*
|
||||
|
||||
build-push:
|
||||
|
|
@ -301,7 +320,7 @@ jobs:
|
|||
retention-days: 1
|
||||
|
||||
merge-images:
|
||||
name: Publish Stable Image Tags
|
||||
name: Publish Release Image Tags
|
||||
needs: [prepare, build-push]
|
||||
if: github.event_name != 'workflow_dispatch'
|
||||
runs-on: ubuntu-latest
|
||||
|
|
@ -336,17 +355,28 @@ jobs:
|
|||
working-directory: /tmp/digests
|
||||
run: |
|
||||
version="${{ needs.prepare.outputs.version }}"
|
||||
major_minor=$(echo "$version" | cut -d. -f1,2)
|
||||
if [ "${{ needs.prepare.outputs.is_prerelease }}" = "true" ]; then
|
||||
docker buildx imagetools create \
|
||||
-t "neosmemo/memos:${version}" \
|
||||
-t "ghcr.io/usememos/memos:${version}" \
|
||||
$(printf 'neosmemo/memos@sha256:%s ' *)
|
||||
exit 0
|
||||
fi
|
||||
|
||||
docker buildx imagetools create \
|
||||
-t "neosmemo/memos:${version}" \
|
||||
-t "neosmemo/memos:${major_minor}" \
|
||||
-t "neosmemo/memos:${{ needs.prepare.outputs.major_minor }}" \
|
||||
-t "neosmemo/memos:stable" \
|
||||
-t "ghcr.io/usememos/memos:${version}" \
|
||||
-t "ghcr.io/usememos/memos:${major_minor}" \
|
||||
-t "ghcr.io/usememos/memos:${{ needs.prepare.outputs.major_minor }}" \
|
||||
-t "ghcr.io/usememos/memos:stable" \
|
||||
$(printf 'neosmemo/memos@sha256:%s ' *)
|
||||
|
||||
- name: Inspect images
|
||||
run: |
|
||||
docker buildx imagetools inspect neosmemo/memos:${{ needs.prepare.outputs.version }}
|
||||
if [ "${{ needs.prepare.outputs.is_prerelease }}" = "true" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
docker buildx imagetools inspect neosmemo/memos:stable
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ Open-source, self-hosted note-taking tool built for quick capture. Markdown-nati
|
|||
<p></p>
|
||||
|
||||
[**TestMu AI** - The world’s first full-stack Agentic AI Quality Engineering platform](https://www.testmuai.com/?utm_medium=sponsor&utm_source=memos)
|
||||
|
||||
|
||||
<a href="https://www.testmuai.com/?utm_medium=sponsor&utm_source=memos" target="_blank" rel="noopener">
|
||||
<img src="https://usememos.com/sponsors/testmu.svg" alt="TestMu AI" height="36" />
|
||||
</a>
|
||||
|
|
@ -44,7 +44,7 @@ Open-source, self-hosted note-taking tool built for quick capture. Markdown-nati
|
|||
<p></p>
|
||||
|
||||
[**SSD Nodes** - Affordable VPS hosting for self-hosters](https://ssdnodes.com/?utm_source=memos&utm_medium=sponsor)
|
||||
|
||||
|
||||
<a href="https://ssdnodes.com/?utm_source=memos&utm_medium=sponsor" target="_blank" rel="noopener">
|
||||
<img src="https://usememos.com/sponsors/ssd-nodes.svg" alt="SSD Nodes" height="72" />
|
||||
</a>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,41 @@
|
|||
## Background & Context
|
||||
|
||||
User resources in Memos v1 are exposed through Connect/gRPC-Gateway handlers in `server/router/api/v1`, proto resource definitions in `proto/api/v1`, frontend profile flows in `web/src`, and MCP JSON helpers in `server/router/mcp`. The store schema already persists both an internal integer `id` and a unique `username` for each user. The GitHub issue reports that public user resource names such as `users/2` are still emitted across responses and nested user-scoped resources. Existing code already mixes identifier forms: `GetUser` accepts either `users/{id}` or `users/{username}`, the fileserver avatar route accepts either identifier, and the frontend profile page already enters the API through `users/{username}` before reusing the returned `user.name`.
|
||||
|
||||
## Issue Statement
|
||||
|
||||
Across the v1 API surface, canonical user resource names are currently constructed from `store.User.ID` rather than `store.User.Username`, and many handlers parse those emitted names back into integers for authorization and lookup. As a result, top-level user resources and nested user-scoped references in settings, stats, shortcuts, webhooks, notifications, memo creators, reactions, and MCP payloads expose sequential database IDs and couple downstream callers to integer-based user tokens in server-emitted names.
|
||||
|
||||
## Current State
|
||||
|
||||
- `store/user.go:26-42` defines `store.User` with both `ID int32` and `Username string`; `store/migration/sqlite/LATEST.sql:10-21` declares `username TEXT NOT NULL UNIQUE`.
|
||||
- `server/router/api/v1/user_service.go:72-102` handles `GetUser` by extracting `users/{id_or_username}` and resolving either a numeric ID or a username; `server/router/api/v1/user_service.go:914-937` still serializes `User.name` as `users/{id}` and derives avatar URLs from that name.
|
||||
- `server/router/api/v1/resource_name.go:67-89` has two different parsing paths: `ExtractUserIDFromName` only accepts numeric user tokens, while `extractUserIdentifierFromName` accepts either token and is currently only used by `GetUser`.
|
||||
- `server/router/api/v1/user_service.go:335-369`, `server/router/api/v1/user_service.go:372-460`, `server/router/api/v1/user_service.go:463-517`, `server/router/api/v1/user_service.go:536-676`, `server/router/api/v1/user_service.go:679-911`, and `server/router/api/v1/user_service.go:1400-1488` parse numeric user segments for settings, personal access tokens, webhooks, and notifications, and emit names such as `users/%d/settings/...`, `users/%d/webhooks/...`, and `users/%d/notifications/%d`.
|
||||
- `server/router/api/v1/shortcut_service.go:20-43` parses `users/{user}/shortcuts/{shortcut}` by converting the `user` segment to `int32`, and constructs shortcut names as `users/%d/shortcuts/%s`.
|
||||
- `server/router/api/v1/user_service_stats.go:63-65`, `server/router/api/v1/user_service_stats.go:113`, `server/router/api/v1/user_service_stats.go:132-145`, `server/router/api/v1/user_service_stats.go:214-223` emit `users/%d/stats` and `users/%d/memos/%d`, and resolve stats requests through numeric `ExtractUserIDFromName`.
|
||||
- `server/router/api/v1/memo_service_converter.go:26-37` serializes `Memo.creator` as `users/{id}`; `server/router/api/v1/reaction_service.go:154-164` serializes `Reaction.creator` as `users/{id}`; `server/router/api/v1/memo_service.go:636-643` and `server/router/api/v1/memo_service.go:815-845` parse `memo.Creator` through the numeric helper for inbox and webhook flows.
|
||||
- `server/router/mcp/tools_memo.go:75-86`, `server/router/mcp/tools_attachment.go:29-37`, and `server/router/mcp/tools_reaction.go:64-71` plus `server/router/mcp/tools_reaction.go:133-138` serialize creator fields as `users/{id}` in MCP tool output.
|
||||
- `server/router/fileserver/fileserver.go:153-181` and `server/router/fileserver/fileserver.go:533-539` currently resolve avatar requests by either numeric ID or username.
|
||||
- `proto/api/v1/user_service.proto:22-29` and `proto/api/v1/user_service.proto:247-256` document `GetUser` accepting both `users/{id}` and `users/{username}`. The same proto file defines the `User` resource at `proto/api/v1/user_service.proto:161-178` and nested user resource formats at `proto/api/v1/user_service.proto:307-317` and `proto/api/v1/user_service.proto:361-373`; example text still uses numeric user tokens such as `users/123/settings/GENERAL`.
|
||||
- `web/src/pages/UserProfile.tsx:74-86` requests `users/{username}` from the route param, and `web/src/layouts/MainLayout.tsx:37-48` stores the returned canonical `user.name` for later stats requests.
|
||||
|
||||
## Non-Goals
|
||||
|
||||
- Replacing internal `user.id` primary keys, foreign keys, or existing store schemas.
|
||||
- Introducing a new opaque UUID-based public user identifier.
|
||||
- Changing user discovery, public profile visibility, or authorization rules beyond how user resource names are parsed and emitted.
|
||||
- Adding username history, redirect, or alias preservation for old usernames after a rename.
|
||||
- Redesigning unrelated resource naming schemes such as memo, attachment, share, or identity-provider identifiers.
|
||||
|
||||
## Open Questions
|
||||
|
||||
- Which public surfaces are in scope for username-based canonical output? (default: all server-emitted v1 API and MCP payload fields that currently contain `users/{...}` resource names)
|
||||
- Should legacy numeric inputs continue to resolve on user-scoped endpoints beyond `GetUser`? (default: no, accept only username-based user resource names)
|
||||
- If a username changes, must previously emitted `users/{old-username}` names continue to resolve? (default: no additional alias or redirect layer; only the current username remains valid)
|
||||
- Should notification, webhook, shortcut, and personal-access-token child identifiers keep their existing child token formats while only the parent user token changes? (default: yes)
|
||||
- Does the issue include avatar URLs and other derived file paths that are built from `User.name`? (default: yes, because avatar URLs are emitted from the same canonical user name field)
|
||||
|
||||
## Scope
|
||||
|
||||
**L** — Current behavior spans `server/router/api/v1`, `server/router/mcp`, `server/router/fileserver`, `proto/api/v1`, frontend consumers in `web/src`, and the request parsers that turn user resource names back into internal IDs. Changing both emitted and accepted user resource names across those surfaces is a broad API contract change rather than a single local edit.
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
## References
|
||||
|
||||
- [AIP-122: Resource names](https://google.aip.dev/122)
|
||||
- [AIP-123: Resource types](https://google.aip.dev/123)
|
||||
- [AIP-148: Standard fields](https://google.aip.dev/148)
|
||||
- [AIP-180: Backwards compatibility](https://google.aip.dev/180)
|
||||
- [Insecure Direct Object Reference Prevention Cheat Sheet](https://cheatsheetseries.owasp.org/cheatsheets/Insecure_Direct_Object_Reference_Prevention_Cheat_Sheet.html)
|
||||
- [REST API endpoints for users - GitHub Docs](https://docs.github.com/en/enterprise-server%403.19/rest/users/users)
|
||||
- [Users API - GitLab Docs](https://docs.gitlab.com/api/users/)
|
||||
- [API Usage - Gitea Documentation](https://docs.gitea.com/next/development/api-usage)
|
||||
|
||||
## Industry Baseline
|
||||
|
||||
`AIP-122: Resource names` and `AIP-148: Standard fields` treat `name` as the canonical identifier that clients store and reuse, and expect request `name` and `parent` fields to accept the same resource-name vocabulary across a service. `AIP-122` also allows aliases for lookup, but requires responses to emit the canonical resource name.
|
||||
|
||||
`REST API endpoints for users - GitHub Docs` and `API Usage - Gitea Documentation` use username-based public user paths and nested user-scoped routes, while keeping numeric or system-assigned identifiers as separate data or alternate endpoints when a durable internal identifier is required.
|
||||
|
||||
`Users API - GitLab Docs` shows a mixed-input compatibility pattern on some endpoints with `id_or_username`, which keeps older callers working while allowing username-oriented public routes.
|
||||
|
||||
`Insecure Direct Object Reference Prevention Cheat Sheet` treats enumerable numeric identifiers as a defense-in-depth concern, but not a substitute for authorization. Replacing `users/{id}` with `users/{username}` changes discoverability characteristics, but permission checks still have to enforce access from internal user IDs.
|
||||
|
||||
`AIP-180: Backwards compatibility` treats changes to resource-name format and server-generated field construction as breaking. Any design that changes emitted `User.name` values inside `v1` has to preserve as much request compatibility as possible and document the remaining response-format risk explicitly.
|
||||
|
||||
## Research Summary
|
||||
|
||||
Memos already has most of the prerequisites for username-based canonical names. The schema stores a unique username, `GetUser` already resolves either ID or username, the fileserver avatar route already uses an `identifier` abstraction, and the frontend profile page already starts from `users/{username}`. No database migration is required to identify users by username at the API boundary.
|
||||
|
||||
The current coupling problem is concentrated in two places. First, response builders serialize `users/{id}` in many modules, including memo conversion, stats, settings, shortcuts, notifications, webhooks, and MCP JSON helpers. Second, many request handlers assume they can parse a numeric ID back out of those names for authorization and storage lookups.
|
||||
|
||||
Research points to a common pattern of canonical public resource names plus server-side resolution to internal IDs. In Memos, switching the canonical token from numeric ID to username can reuse the existing unique username column and existing username lookups, but `AIP-123: Resource types` and `AIP-180: Backwards compatibility` still make clear that changing accepted and emitted resource-name formats inside `v1` is a breaking API contract change. That makes this design a deliberate contract replacement rather than a compatibility layer.
|
||||
|
||||
## Design Goals
|
||||
|
||||
- All server-emitted v1 and MCP response fields that serialize user resource names under `users/{...}` use the current username token instead of the numeric database ID.
|
||||
- User-scoped request fields that reference `users/{...}` accept username-based resource names only.
|
||||
- Authorization, ownership checks, inbox/webhook dispatch, and other internal workflows continue to operate on `store.User.ID` after resolving the public resource name.
|
||||
- List and batch endpoints avoid introducing per-item user lookups when serializing username-based names.
|
||||
- No database schema, foreign-key, or storage-key redesign is required.
|
||||
|
||||
## Non-Goals
|
||||
|
||||
- Replacing internal `user.id` primary keys, foreign keys, or existing store schemas.
|
||||
- Introducing a new opaque UUID-based public user identifier.
|
||||
- Changing user discovery, public profile visibility, or authorization rules beyond how user resource names are parsed and emitted.
|
||||
- Adding username history, redirect, or alias preservation for old usernames after a rename.
|
||||
- Redesigning unrelated resource naming schemes such as memo, attachment, share, or identity-provider identifiers.
|
||||
- Adding a new API version as part of this issue.
|
||||
|
||||
## Proposed Design
|
||||
|
||||
Introduce a single canonical user-name builder in the v1 API layer that serializes `users/{username}` from resolved user data, and route every public user-name emitter through it. This includes `convertUserFromStore`, memo and reaction creator fields, user stats, settings, shortcuts, webhooks, notifications, personal-access-token names, webhook payloads, avatar URLs derived from `User.name`, and the MCP JSON helpers. This satisfies the first design goal and aligns the public resource shape with `AIP-122: Resource names`.
|
||||
|
||||
Introduce a shared user-token resolver in `server/router/api/v1` that extracts the `users/{token}` segment, validates it as a username-form resource token, resolves the corresponding `store.User`, and then passes the resolved internal ID into permission checks and storage lookups. This replaces numeric-only parsing in helpers such as `ExtractUserIDFromName`, `ExtractUserIDAndSettingKeyFromName`, shortcut and webhook parsers, personal-access-token deletion, and notification parsing. The fileserver's current `getUserByIdentifier` behavior shows both lookup styles exist today, but the API-layer contract for this issue becomes username-only rather than dual-mode.
|
||||
|
||||
Keep child resource tokens unchanged and only change the user segment. For names such as `users/{user}/settings/{setting}`, `users/{user}/webhooks/{webhook}`, `users/{user}/notifications/{notification}`, `users/{user}/shortcuts/{shortcut}`, and `users/{user}/personalAccessTokens/{token}`, the parent `user` token is resolved from the username, while the child token keeps its existing format and storage mapping. This is narrower than redesigning child identifiers and keeps the issue bounded to the user-resource segment.
|
||||
|
||||
Use response-side user resolution strategies that match endpoint shape. Single-resource handlers can resolve one user directly and serialize the username immediately. List and batch handlers such as memo conversion, stats aggregation, notifications, and MCP list output should collect distinct user IDs first and resolve usernames once per response, reusing the store's existing user lookup path and cache where available. This keeps username-based output from turning into hidden N+1 query behavior and satisfies the performance goal without changing persistence.
|
||||
|
||||
Replace the public user-resource contract rather than extending it. Server-emitted `name`, `parent`, `creator`, and `sender` fields become username-based canonical output, and handlers that currently accept `users/{id}` are updated to require `users/{username}`. `AIP-180: Backwards compatibility` indicates that changing both the construction and accepted format of an existing resource name is a breaking change for clients that persist, compare, or generate old `name` values. The design therefore requires updated proto comments, API examples, handler tests, and release notes to make the new canonical form and the removed numeric form explicit.
|
||||
|
||||
Do not add a username alias table in this issue. If a username changes, newly serialized resource names use the current username, and previously emitted username-based names stop resolving unless they match the current username. This keeps the scope aligned with existing `UpdateUser` behavior and avoids introducing a new subsystem for historical username resolution. The alternative of adding permanent old-username aliases was rejected because it expands the problem from canonical serialization into identity-history management.
|
||||
|
||||
Do not solve this by adding a second public identifier field and leaving `User.name` numeric. `AIP-122: Resource names` treats `name` as the canonical resource identifier, and the GitHub issue is specifically about the public names currently emitted under `users/{id}`. Adding a second field would preserve the exposed sequential identifier in the canonical slot and fail the primary design goal. Likewise, introducing a new opaque UUID-based public identifier was rejected because the repository already has a unique username field and the issue is scoped to replacing numeric user resource names with that existing identifier.
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
## Execution Log
|
||||
|
||||
### T1: Add username-only user resource helpers
|
||||
|
||||
**Status**: Completed
|
||||
**Files Changed**:
|
||||
- `server/router/api/v1/user_resource_name.go`
|
||||
- `server/router/api/v1/resource_name.go`
|
||||
- `server/router/api/v1/user_service.go`
|
||||
- `server/router/api/v1/test/user_resource_name_test.go`
|
||||
**Validation**: `go test -v ./server/router/api/v1/test -run 'TestUserResourceName'` — PASS
|
||||
**Path Corrections**: Tightened username-token validation so numeric-only `users/1` fails at the resource-name layer instead of falling through to `NotFound`.
|
||||
**Deviations**: None
|
||||
|
||||
### T2: Migrate user-scoped API handlers
|
||||
|
||||
**Status**: Completed
|
||||
**Files Changed**:
|
||||
- `server/router/api/v1/user_service.go`
|
||||
- `server/router/api/v1/shortcut_service.go`
|
||||
- `server/router/api/v1/user_service_stats.go`
|
||||
- `server/router/api/v1/test/shortcut_service_test.go`
|
||||
- `server/router/api/v1/test/user_service_stats_test.go`
|
||||
- `server/router/api/v1/test/user_notification_test.go`
|
||||
- `server/router/api/v1/test/user_service_registration_test.go`
|
||||
**Validation**: `go test -v ./server/router/api/v1/test -run 'Test(ListShortcuts|GetShortcut|CreateShortcut|UpdateShortcut|DeleteShortcut|ShortcutFiltering|ShortcutCRUDComplete|GetUserStats_TagCount|ListUserNotifications|UserRegistration)'` — PASS
|
||||
**Path Corrections**: Updated test fixtures to use valid username-form resource names (`users/testuser`, `users/test-user`) and corrected one stale registration-name expectation during the later broader suite rerun.
|
||||
**Deviations**: None
|
||||
|
||||
### T3: Migrate memo, reaction, MCP, and avatar user references
|
||||
|
||||
**Status**: Completed
|
||||
**Files Changed**:
|
||||
- `server/router/api/v1/memo_service_converter.go`
|
||||
- `server/router/api/v1/memo_service.go`
|
||||
- `server/router/api/v1/reaction_service.go`
|
||||
- `server/router/mcp/tools_memo.go`
|
||||
- `server/router/mcp/tools_attachment.go`
|
||||
- `server/router/mcp/tools_reaction.go`
|
||||
- `server/router/fileserver/fileserver.go`
|
||||
- `server/router/api/v1/test/memo_service_test.go`
|
||||
- `server/router/api/v1/test/reaction_service_test.go`
|
||||
**Validation**: `go test ./server/router/api/v1/... ./server/router/mcp/... ./server/router/fileserver/...` — PASS
|
||||
**Path Corrections**: Removed an unused fileserver import after the first package build failed; kept MCP tool helper signatures stable for undeclared callers and switched tool call sites to username-aware wrappers.
|
||||
**Deviations**: None
|
||||
|
||||
### T4: Update contract docs and regression tests
|
||||
|
||||
**Status**: Completed
|
||||
**Files Changed**:
|
||||
- `proto/api/v1/user_service.proto`
|
||||
- `proto/api/v1/shortcut_service.proto`
|
||||
- `web/src/layouts/MainLayout.tsx`
|
||||
- `web/src/components/MemoExplorer/ShortcutsSection.tsx`
|
||||
- `server/router/fileserver/README.md`
|
||||
**Validation**: `go test -v ./server/router/api/v1/test/...` — PASS
|
||||
**Path Corrections**: None
|
||||
**Deviations**: None
|
||||
|
||||
## Completion Declaration
|
||||
|
||||
All tasks completed successfully
|
||||
|
|
@ -0,0 +1,106 @@
|
|||
## Task List
|
||||
|
||||
Task Index
|
||||
T1: Add username-only user resource helpers [L] — T2: Migrate user-scoped API handlers [L] — T3: Migrate memo, reaction, MCP, and avatar user references [L] — T4: Update contract docs and regression tests [L]
|
||||
|
||||
### T1: Add username-only user resource helpers [L]
|
||||
|
||||
**Objective**: Establish one v1 API mechanism for serializing `users/{username}` and resolving username-based user resource names back to internal user records, including root `GetUser` handling.
|
||||
**Size**: L (multiple files, shared identifier logic used across handlers)
|
||||
**Files**:
|
||||
- Create: `server/router/api/v1/user_resource_name.go`
|
||||
- Modify: `server/router/api/v1/resource_name.go`
|
||||
- Modify: `server/router/api/v1/user_service.go`
|
||||
- Test: `server/router/api/v1/test/user_resource_name_test.go`
|
||||
**Implementation**:
|
||||
1. In `server/router/api/v1/user_resource_name.go`: add the shared helper surface for canonical user-name construction, extracting the `users/{token}` segment, validating the username-form token, and resolving the corresponding `store.User`.
|
||||
2. In `server/router/api/v1/resource_name.go`: replace `ExtractUserIDFromName()`’s numeric-only behavior with username-oriented resolution helpers or thin wrappers that delegate to the new shared module.
|
||||
3. In `server/router/api/v1/user_service.go`: update `GetUser()` (~lines 72-102) and `convertUserFromStore()` (~lines 914-937) to use username-only resource names and reject legacy numeric `users/{id}` requests.
|
||||
4. In `server/router/api/v1/test/user_resource_name_test.go`: add direct coverage for `GetUser users/{username}` success, canonical `User.name == users/{username}`, and rejection of `users/{id}`.
|
||||
**Boundaries**: Do not migrate nested user-scoped handlers, memo/reaction emitters, MCP output, or fileserver behavior in this task.
|
||||
**Dependencies**: None
|
||||
**Expected Outcome**: Shared username-only helper logic exists, root user resources serialize as `users/{username}`, and root numeric user-name requests fail.
|
||||
**Validation**: `go test -v ./server/router/api/v1/test -run 'TestUserResourceName'` — expected output includes `PASS` and `ok`
|
||||
|
||||
### T2: Migrate user-scoped API handlers [L]
|
||||
|
||||
**Objective**: Convert user-scoped v1 handlers and nested resource emitters to require `users/{username}` while continuing to authorize and store by resolved internal user ID.
|
||||
**Size**: L (multiple handlers in one large service plus shortcut and stats code)
|
||||
**Files**:
|
||||
- Modify: `server/router/api/v1/user_service.go`
|
||||
- Modify: `server/router/api/v1/shortcut_service.go`
|
||||
- Modify: `server/router/api/v1/user_service_stats.go`
|
||||
- Test: `server/router/api/v1/test/shortcut_service_test.go`
|
||||
- Test: `server/router/api/v1/test/user_service_stats_test.go`
|
||||
- Test: `server/router/api/v1/test/user_notification_test.go`
|
||||
- Test: `server/router/api/v1/test/user_service_registration_test.go`
|
||||
**Implementation**:
|
||||
1. In `server/router/api/v1/user_service.go`: update settings, PAT, webhook, and notification parsing/emission paths (~lines 335-911 and ~1400-1488) to resolve `users/{username}` and emit username-based parent/child resource names.
|
||||
2. In `server/router/api/v1/shortcut_service.go`: update shortcut name parsing and construction (~lines 20-43) plus handler entry points to use username parents and nested names.
|
||||
3. In `server/router/api/v1/user_service_stats.go`: update stats request parsing and `UserStats.name` / `PinnedMemos` serialization (~lines 63-65, 113, 132-145, 214-223) to use usernames.
|
||||
4. In the listed tests: replace numeric user-name inputs with username-based parents, assert username-based emitted names, and add numeric-request rejection coverage for representative user-scoped endpoints.
|
||||
**Boundaries**: Do not change memo/reaction creator fields, MCP JSON output, or fileserver avatar routing in this task.
|
||||
**Dependencies**: T1
|
||||
**Expected Outcome**: User settings, notifications, shortcuts, stats, PATs, and webhooks all accept only `users/{username}` and emit only username-based user resource names.
|
||||
**Validation**: `go test -v ./server/router/api/v1/test -run 'Test(ListShortcuts|GetShortcut|CreateShortcut|UpdateShortcut|DeleteShortcut|ShortcutFiltering|ShortcutCRUDComplete|GetUserStats_TagCount|ListUserNotifications|UserRegistration)'` — expected output includes `PASS` and `ok`
|
||||
|
||||
### T3: Migrate memo, reaction, MCP, and avatar user references [L]
|
||||
|
||||
**Objective**: Remove numeric user resource names from memo/reaction-related API responses, dependent webhook/inbox flows, MCP JSON output, and avatar URLs/routing.
|
||||
**Size**: L (cross-package serialization and lookup changes, including response-side user resolution)
|
||||
**Files**:
|
||||
- Modify: `server/router/api/v1/memo_service_converter.go`
|
||||
- Modify: `server/router/api/v1/memo_service.go`
|
||||
- Modify: `server/router/api/v1/reaction_service.go`
|
||||
- Modify: `server/router/mcp/tools_memo.go`
|
||||
- Modify: `server/router/mcp/tools_attachment.go`
|
||||
- Modify: `server/router/mcp/tools_reaction.go`
|
||||
- Modify: `server/router/fileserver/fileserver.go`
|
||||
- Test: `server/router/api/v1/test/memo_service_test.go`
|
||||
- Test: `server/router/api/v1/test/reaction_service_test.go`
|
||||
**Implementation**:
|
||||
1. In `server/router/api/v1/memo_service_converter.go`: update `convertMemoFromStore()` (~lines 16-73) to serialize `Memo.creator` from resolved usernames rather than numeric IDs, using response-side batching or shared lookup helpers so list responses do not regress into hidden per-item lookups.
|
||||
2. In `server/router/api/v1/reaction_service.go`: update `convertReactionFromStore()` (~lines 154-164) to emit username-based creators.
|
||||
3. In `server/router/api/v1/memo_service.go`: update memo comment, webhook dispatch, and webhook payload helpers (~lines 636-643 and 815-845) to resolve username-based memo creators before using internal IDs.
|
||||
4. In `server/router/mcp/tools_memo.go`, `server/router/mcp/tools_attachment.go`, and `server/router/mcp/tools_reaction.go`: replace `users/%d` creator serialization with username-based values.
|
||||
5. In `server/router/fileserver/fileserver.go`: change avatar lookup to accept username identifiers only and ensure avatar URLs derived from `User.name` continue to resolve under `users/{username}`.
|
||||
6. In the listed tests: update creator assertions to `users/{username}` and add representative rejection coverage where numeric user names previously flowed through memo/reaction-related paths.
|
||||
**Boundaries**: Do not update proto comments, README examples, or frontend comments in this task.
|
||||
**Dependencies**: T1
|
||||
**Expected Outcome**: Memo/reaction creators, webhook payload creators, MCP creator fields, and avatar-derived user paths no longer expose numeric user IDs.
|
||||
**Validation**: `go test ./server/router/api/v1/... ./server/router/mcp/... ./server/router/fileserver/...` — expected output includes `ok` for all touched packages
|
||||
|
||||
### T4: Update contract docs and regression tests [L]
|
||||
|
||||
**Objective**: Align public contract comments/examples and the final regression suite with the username-only user resource-name contract.
|
||||
**Size**: L (multiple contract/documentation files plus end-to-end regression coverage)
|
||||
**Files**:
|
||||
- Modify: `proto/api/v1/user_service.proto`
|
||||
- Modify: `proto/api/v1/shortcut_service.proto`
|
||||
- Modify: `web/src/layouts/MainLayout.tsx`
|
||||
- Modify: `web/src/components/MemoExplorer/ShortcutsSection.tsx`
|
||||
- Modify: `server/router/fileserver/README.md`
|
||||
- Modify: `server/router/api/v1/test/user_resource_name_test.go`
|
||||
- Modify: `server/router/api/v1/test/shortcut_service_test.go`
|
||||
- Modify: `server/router/api/v1/test/user_service_stats_test.go`
|
||||
- Modify: `server/router/api/v1/test/user_notification_test.go`
|
||||
- Modify: `server/router/api/v1/test/memo_service_test.go`
|
||||
- Modify: `server/router/api/v1/test/reaction_service_test.go`
|
||||
- Modify: `server/router/api/v1/test/user_service_registration_test.go`
|
||||
**Implementation**:
|
||||
1. In `proto/api/v1/user_service.proto` and `proto/api/v1/shortcut_service.proto`: rewrite resource-name comments and examples so they document username-only user resource names and remove `users/{id}` examples.
|
||||
2. In `web/src/layouts/MainLayout.tsx` and `web/src/components/MemoExplorer/ShortcutsSection.tsx`: update inline comments/examples that still describe numeric user resource names.
|
||||
3. In `server/router/fileserver/README.md`: replace numeric avatar examples with username-based examples.
|
||||
4. In the listed test files: finish any remaining request/response assertions so the suite consistently encodes the username-only contract and explicitly rejects numeric user resource names where that contract is externally visible.
|
||||
**Boundaries**: Do not add schema migrations, generated proto output refreshes, or username-history behavior.
|
||||
**Dependencies**: T2, T3
|
||||
**Expected Outcome**: Source comments, examples, and regression tests all describe and enforce a username-only `users/{username}` public contract.
|
||||
**Validation**: `go test -v ./server/router/api/v1/test/...` — expected output includes `PASS` and `ok`
|
||||
|
||||
## Out-of-Scope Tasks
|
||||
|
||||
- Database schema or migration changes for the `user` table or foreign keys.
|
||||
- Username history, alias, redirect, or backward-compatibility layers.
|
||||
- A new opaque public user identifier or a new API version.
|
||||
- Opportunistic refactors outside the files listed above.
|
||||
- Generated code refreshes (`buf generate`) unless a later approved plan revision explicitly requires schema changes.
|
||||
|
|
@ -1,12 +1,14 @@
|
|||
# Memo Filter Engine
|
||||
|
||||
This package houses the memo-only filter engine that turns CEL expressions into
|
||||
SQL fragments. The engine follows a three phase pipeline inspired by systems
|
||||
This package houses the memo-only filter engine that turns standard CEL syntax
|
||||
into SQL fragments for the subset of expressions supported by the memo schema.
|
||||
The engine follows a three phase pipeline inspired by systems
|
||||
such as Calcite or Prisma:
|
||||
|
||||
1. **Parsing** – CEL expressions are parsed with `cel-go` and validated against
|
||||
the memo-specific environment declared in `schema.go`. Only fields that
|
||||
exist in the schema can surface in the filter.
|
||||
exist in the schema can surface in the filter, and non-standard legacy
|
||||
coercions are rejected.
|
||||
2. **Normalization** – the raw CEL AST is converted into an intermediate
|
||||
representation (IR) defined in `ir.go`. The IR is a dialect-agnostic tree of
|
||||
conditions (logical operators, comparisons, list membership, etc.). This
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package filter
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
|
|
@ -45,8 +44,6 @@ func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) {
|
|||
return nil, errors.New("filter expression is empty")
|
||||
}
|
||||
|
||||
filter = normalizeLegacyFilter(filter)
|
||||
|
||||
ast, issues := e.env.Compile(filter)
|
||||
if issues != nil && issues.Err() != nil {
|
||||
return nil, errors.Wrap(issues.Err(), "failed to compile filter")
|
||||
|
|
@ -119,73 +116,3 @@ func DefaultAttachmentEngine() (*Engine, error) {
|
|||
})
|
||||
return defaultAttachmentInst, defaultAttachmentErr
|
||||
}
|
||||
|
||||
func normalizeLegacyFilter(expr string) string {
|
||||
expr = rewriteNumericLogicalOperand(expr, "&&")
|
||||
expr = rewriteNumericLogicalOperand(expr, "||")
|
||||
return expr
|
||||
}
|
||||
|
||||
func rewriteNumericLogicalOperand(expr, op string) string {
|
||||
var builder strings.Builder
|
||||
n := len(expr)
|
||||
i := 0
|
||||
var inQuote rune
|
||||
|
||||
for i < n {
|
||||
ch := expr[i]
|
||||
|
||||
if inQuote != 0 {
|
||||
builder.WriteByte(ch)
|
||||
if ch == '\\' && i+1 < n {
|
||||
builder.WriteByte(expr[i+1])
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
if ch == byte(inQuote) {
|
||||
inQuote = 0
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '\'' || ch == '"' {
|
||||
inQuote = rune(ch)
|
||||
builder.WriteByte(ch)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(expr[i:], op) {
|
||||
builder.WriteString(op)
|
||||
i += len(op)
|
||||
|
||||
// Preserve whitespace following the operator.
|
||||
wsStart := i
|
||||
for i < n && (expr[i] == ' ' || expr[i] == '\t') {
|
||||
i++
|
||||
}
|
||||
builder.WriteString(expr[wsStart:i])
|
||||
|
||||
signStart := i
|
||||
if i < n && (expr[i] == '+' || expr[i] == '-') {
|
||||
i++
|
||||
}
|
||||
for i < n && expr[i] >= '0' && expr[i] <= '9' {
|
||||
i++
|
||||
}
|
||||
if i > signStart {
|
||||
numLiteral := expr[signStart:i]
|
||||
fmt.Fprintf(&builder, "(%s != 0)", numLiteral)
|
||||
} else {
|
||||
builder.WriteString(expr[signStart:i])
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
builder.WriteByte(ch)
|
||||
i++
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
package filter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCompileAcceptsStandardTagEqualityPredicate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
engine, err := NewEngine(NewSchema())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = engine.Compile(context.Background(), `tags.exists(t, t == "1231")`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCompileRejectsLegacyNumericLogicalOperand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
engine, err := NewEngine(NewSchema())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = engine.Compile(context.Background(), `pinned && 1`)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "failed to compile filter")
|
||||
}
|
||||
|
||||
func TestCompileRejectsNonBooleanTopLevelConstant(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
engine, err := NewEngine(NewSchema())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = engine.Compile(context.Background(), `1`)
|
||||
require.EqualError(t, err, "filter must evaluate to a boolean value")
|
||||
}
|
||||
|
|
@ -157,3 +157,10 @@ type ContainsPredicate struct {
|
|||
}
|
||||
|
||||
func (*ContainsPredicate) isPredicateExpr() {}
|
||||
|
||||
// EqualsPredicate represents t == "value".
|
||||
type EqualsPredicate struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
func (*EqualsPredicate) isPredicateExpr() {}
|
||||
|
|
|
|||
|
|
@ -16,16 +16,10 @@ func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
if v, ok := val.(bool); ok {
|
||||
return &ConstantCondition{Value: v}, nil
|
||||
case int64:
|
||||
return &ConstantCondition{Value: v != 0}, nil
|
||||
case float64:
|
||||
return &ConstantCondition{Value: v != 0}, nil
|
||||
default:
|
||||
return nil, errors.New("filter must evaluate to a boolean value")
|
||||
}
|
||||
return nil, errors.New("filter must evaluate to a boolean value")
|
||||
case *exprv1.Expr_IdentExpr:
|
||||
name := v.IdentExpr.GetName()
|
||||
field, ok := schema.Field(name)
|
||||
|
|
@ -504,6 +498,8 @@ func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr,
|
|||
|
||||
// Handle different predicate functions
|
||||
switch predicateCall.Function {
|
||||
case "_==_":
|
||||
return buildEqualsPredicate(predicateCall, comp.IterVar)
|
||||
case "startsWith":
|
||||
return buildStartsWithPredicate(predicateCall, comp.IterVar)
|
||||
case "endsWith":
|
||||
|
|
@ -511,10 +507,44 @@ func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr,
|
|||
case "contains":
|
||||
return buildContainsPredicate(predicateCall, comp.IterVar)
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function)
|
||||
return nil, errors.Errorf(`unsupported predicate function %q in comprehension (supported: ==, startsWith, endsWith, contains)`, predicateCall.Function)
|
||||
}
|
||||
}
|
||||
|
||||
// buildEqualsPredicate extracts the value from t == "value".
|
||||
func buildEqualsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
if len(call.Args) != 2 {
|
||||
return nil, errors.New("equality predicate expects exactly two arguments")
|
||||
}
|
||||
|
||||
var constExpr *exprv1.Expr
|
||||
switch {
|
||||
case isIterVarExpr(call.Args[0], iterVar):
|
||||
constExpr = call.Args[1]
|
||||
case isIterVarExpr(call.Args[1], iterVar):
|
||||
constExpr = call.Args[0]
|
||||
default:
|
||||
return nil, errors.Errorf("equality predicate must compare against the iteration variable %q", iterVar)
|
||||
}
|
||||
|
||||
value, err := getConstValue(constExpr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "equality argument must be a constant string")
|
||||
}
|
||||
|
||||
valueStr, ok := value.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("equality argument must be a string")
|
||||
}
|
||||
|
||||
return &EqualsPredicate{Value: valueStr}, nil
|
||||
}
|
||||
|
||||
func isIterVarExpr(expr *exprv1.Expr, iterVar string) bool {
|
||||
target := expr.GetIdentExpr()
|
||||
return target != nil && target.GetName() == iterVar
|
||||
}
|
||||
|
||||
// buildStartsWithPredicate extracts the pattern from t.startsWith("prefix").
|
||||
func buildStartsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) {
|
||||
// Verify the target is the iteration variable
|
||||
|
|
|
|||
|
|
@ -480,6 +480,8 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re
|
|||
|
||||
// Render based on predicate type
|
||||
switch pred := cond.Predicate.(type) {
|
||||
case *EqualsPredicate:
|
||||
return r.renderTagEquals(field, pred.Value, cond.Kind)
|
||||
case *StartsWithPredicate:
|
||||
return r.renderTagStartsWith(field, pred.Prefix, cond.Kind)
|
||||
case *EndsWithPredicate:
|
||||
|
|
@ -491,6 +493,22 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re
|
|||
}
|
||||
}
|
||||
|
||||
// renderTagEquals generates SQL for tags.exists(t, t == "value").
|
||||
func (r *renderer) renderTagEquals(field Field, value string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
|
||||
switch r.dialect {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
exactMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s"%%`, value))
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, exactMatch)}, nil
|
||||
case DialectPostgres:
|
||||
exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(fmt.Sprintf(`"%s"`, value)))
|
||||
return renderResult{sql: r.wrapWithNullCheck(arrayExpr, exactMatch)}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
// renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")).
|
||||
func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) {
|
||||
arrayExpr := jsonArrayExpr(r.dialect, field)
|
||||
|
|
|
|||
|
|
@ -108,6 +108,21 @@ func NewSchema() Schema {
|
|||
SupportsContains: true,
|
||||
Expressions: map[DialectName]string{},
|
||||
},
|
||||
"creator": {
|
||||
Name: "creator",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeString,
|
||||
Column: Column{Table: "memo_creator", Name: "username"},
|
||||
Expressions: map[DialectName]string{
|
||||
DialectSQLite: "('users/' || %s)",
|
||||
DialectMySQL: "CONCAT('users/', %s)",
|
||||
DialectPostgres: "('users/' || %s)",
|
||||
},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
"creator_id": {
|
||||
Name: "creator_id",
|
||||
Kind: FieldKindScalar,
|
||||
|
|
@ -228,6 +243,7 @@ func NewSchema() Schema {
|
|||
|
||||
envOptions := []cel.EnvOption{
|
||||
cel.Variable("content", cel.StringType),
|
||||
cel.Variable("creator", cel.StringType),
|
||||
cel.Variable("creator_id", cel.IntType),
|
||||
cel.Variable("created_ts", cel.IntType),
|
||||
cel.Variable("updated_ts", cel.IntType),
|
||||
|
|
|
|||
|
|
@ -167,7 +167,8 @@ message InstanceSetting {
|
|||
|
||||
// Metadata for a tag.
|
||||
message TagMetadata {
|
||||
// Background color for the tag label.
|
||||
// Optional background color for the tag label.
|
||||
// When unset, the default tag color is used.
|
||||
google.type.Color background_color = 1;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -52,13 +52,13 @@ service ShortcutService {
|
|||
message Shortcut {
|
||||
option (google.api.resource) = {
|
||||
type: "memos.api.v1/Shortcut"
|
||||
pattern: "users/{user}/shortcuts/{shortcut}"
|
||||
pattern: "users/{username}/shortcuts/{shortcut}"
|
||||
singular: "shortcut"
|
||||
plural: "shortcuts"
|
||||
};
|
||||
|
||||
// The resource name of the shortcut.
|
||||
// Format: users/{user}/shortcuts/{shortcut}
|
||||
// Format: users/{username}/shortcuts/{shortcut}
|
||||
string name = 1 [(google.api.field_behavior) = IDENTIFIER];
|
||||
|
||||
// The title of the shortcut.
|
||||
|
|
@ -70,7 +70,7 @@ message Shortcut {
|
|||
|
||||
message ListShortcutsRequest {
|
||||
// Required. The parent resource where shortcuts are listed.
|
||||
// Format: users/{user}
|
||||
// Format: users/{username}
|
||||
string parent = 1 [
|
||||
(google.api.field_behavior) = REQUIRED,
|
||||
(google.api.resource_reference) = {child_type: "memos.api.v1/Shortcut"}
|
||||
|
|
@ -84,7 +84,7 @@ message ListShortcutsResponse {
|
|||
|
||||
message GetShortcutRequest {
|
||||
// Required. The resource name of the shortcut to retrieve.
|
||||
// Format: users/{user}/shortcuts/{shortcut}
|
||||
// Format: users/{username}/shortcuts/{shortcut}
|
||||
string name = 1 [
|
||||
(google.api.field_behavior) = REQUIRED,
|
||||
(google.api.resource_reference) = {type: "memos.api.v1/Shortcut"}
|
||||
|
|
@ -93,7 +93,7 @@ message GetShortcutRequest {
|
|||
|
||||
message CreateShortcutRequest {
|
||||
// Required. The parent resource where this shortcut will be created.
|
||||
// Format: users/{user}
|
||||
// Format: users/{username}
|
||||
string parent = 1 [
|
||||
(google.api.field_behavior) = REQUIRED,
|
||||
(google.api.resource_reference) = {child_type: "memos.api.v1/Shortcut"}
|
||||
|
|
@ -116,7 +116,7 @@ message UpdateShortcutRequest {
|
|||
|
||||
message DeleteShortcutRequest {
|
||||
// Required. The resource name of the shortcut to delete.
|
||||
// Format: users/{user}/shortcuts/{shortcut}
|
||||
// Format: users/{username}/shortcuts/{shortcut}
|
||||
string name = 1 [
|
||||
(google.api.field_behavior) = REQUIRED,
|
||||
(google.api.resource_reference) = {type: "memos.api.v1/Shortcut"}
|
||||
|
|
|
|||
|
|
@ -19,10 +19,8 @@ service UserService {
|
|||
option (google.api.http) = {get: "/api/v1/users"};
|
||||
}
|
||||
|
||||
// GetUser gets a user by ID or username.
|
||||
// Supports both numeric IDs and username strings:
|
||||
// - users/{id} (e.g., users/101)
|
||||
// - users/{username} (e.g., users/steven)
|
||||
// GetUser gets a user by username.
|
||||
// Format: users/{username} (e.g., users/steven)
|
||||
rpc GetUser(GetUserRequest) returns (User) {
|
||||
option (google.api.http) = {get: "/api/v1/{name=users/*}"};
|
||||
option (google.api.method_signature) = "name";
|
||||
|
|
@ -246,10 +244,7 @@ message ListUsersResponse {
|
|||
|
||||
message GetUserRequest {
|
||||
// Required. The resource name of the user.
|
||||
// Supports both numeric IDs and username strings:
|
||||
// - users/{id} (e.g., users/101)
|
||||
// - users/{username} (e.g., users/steven)
|
||||
// Format: users/{id_or_username}
|
||||
// Format: users/{username}
|
||||
string name = 1 [
|
||||
(google.api.field_behavior) = REQUIRED,
|
||||
(google.api.resource_reference) = {type: "memos.api.v1/User"}
|
||||
|
|
@ -362,14 +357,14 @@ message ListAllUserStatsResponse {
|
|||
message UserSetting {
|
||||
option (google.api.resource) = {
|
||||
type: "memos.api.v1/UserSetting"
|
||||
pattern: "users/{user}/settings/{setting}"
|
||||
pattern: "users/{username}/settings/{setting}"
|
||||
singular: "userSetting"
|
||||
plural: "userSettings"
|
||||
};
|
||||
|
||||
// The name of the user setting.
|
||||
// Format: users/{user}/settings/{setting}, {setting} is the key for the setting.
|
||||
// For example, "users/123/settings/GENERAL" for general settings.
|
||||
// Format: users/{username}/settings/{setting}, {setting} is the key for the setting.
|
||||
// For example, "users/steven/settings/GENERAL" for general settings.
|
||||
string name = 1 [(google.api.field_behavior) = IDENTIFIER];
|
||||
|
||||
oneof value {
|
||||
|
|
|
|||
|
|
@ -95,10 +95,8 @@ const (
|
|||
type UserServiceClient interface {
|
||||
// ListUsers returns a list of users.
|
||||
ListUsers(context.Context, *connect.Request[v1.ListUsersRequest]) (*connect.Response[v1.ListUsersResponse], error)
|
||||
// GetUser gets a user by ID or username.
|
||||
// Supports both numeric IDs and username strings:
|
||||
// - users/{id} (e.g., users/101)
|
||||
// - users/{username} (e.g., users/steven)
|
||||
// GetUser gets a user by username.
|
||||
// Format: users/{username} (e.g., users/steven)
|
||||
GetUser(context.Context, *connect.Request[v1.GetUserRequest]) (*connect.Response[v1.User], error)
|
||||
// CreateUser creates a new user.
|
||||
CreateUser(context.Context, *connect.Request[v1.CreateUserRequest]) (*connect.Response[v1.User], error)
|
||||
|
|
@ -402,10 +400,8 @@ func (c *userServiceClient) DeleteUserNotification(ctx context.Context, req *con
|
|||
type UserServiceHandler interface {
|
||||
// ListUsers returns a list of users.
|
||||
ListUsers(context.Context, *connect.Request[v1.ListUsersRequest]) (*connect.Response[v1.ListUsersResponse], error)
|
||||
// GetUser gets a user by ID or username.
|
||||
// Supports both numeric IDs and username strings:
|
||||
// - users/{id} (e.g., users/101)
|
||||
// - users/{username} (e.g., users/steven)
|
||||
// GetUser gets a user by username.
|
||||
// Format: users/{username} (e.g., users/steven)
|
||||
GetUser(context.Context, *connect.Request[v1.GetUserRequest]) (*connect.Response[v1.User], error)
|
||||
// CreateUser creates a new user.
|
||||
CreateUser(context.Context, *connect.Request[v1.CreateUserRequest]) (*connect.Response[v1.User], error)
|
||||
|
|
|
|||
|
|
@ -759,7 +759,8 @@ func (x *InstanceSetting_MemoRelatedSetting) GetReactions() []string {
|
|||
// Metadata for a tag.
|
||||
type InstanceSetting_TagMetadata struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Background color for the tag label.
|
||||
// Optional background color for the tag label.
|
||||
// When unset, the default tag color is used.
|
||||
BackgroundColor *color.Color `protobuf:"bytes,1,opt,name=background_color,json=backgroundColor,proto3" json:"background_color,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ const (
|
|||
type Shortcut struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// The resource name of the shortcut.
|
||||
// Format: users/{user}/shortcuts/{shortcut}
|
||||
// Format: users/{username}/shortcuts/{shortcut}
|
||||
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
|
||||
// The title of the shortcut.
|
||||
Title string `protobuf:"bytes,2,opt,name=title,proto3" json:"title,omitempty"`
|
||||
|
|
@ -91,7 +91,7 @@ func (x *Shortcut) GetFilter() string {
|
|||
type ListShortcutsRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Required. The parent resource where shortcuts are listed.
|
||||
// Format: users/{user}
|
||||
// Format: users/{username}
|
||||
Parent string `protobuf:"bytes,1,opt,name=parent,proto3" json:"parent,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
|
|
@ -182,7 +182,7 @@ func (x *ListShortcutsResponse) GetShortcuts() []*Shortcut {
|
|||
type GetShortcutRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Required. The resource name of the shortcut to retrieve.
|
||||
// Format: users/{user}/shortcuts/{shortcut}
|
||||
// Format: users/{username}/shortcuts/{shortcut}
|
||||
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
|
|
@ -228,7 +228,7 @@ func (x *GetShortcutRequest) GetName() string {
|
|||
type CreateShortcutRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Required. The parent resource where this shortcut will be created.
|
||||
// Format: users/{user}
|
||||
// Format: users/{username}
|
||||
Parent string `protobuf:"bytes,1,opt,name=parent,proto3" json:"parent,omitempty"`
|
||||
// Required. The shortcut to create.
|
||||
Shortcut *Shortcut `protobuf:"bytes,2,opt,name=shortcut,proto3" json:"shortcut,omitempty"`
|
||||
|
|
@ -346,7 +346,7 @@ func (x *UpdateShortcutRequest) GetUpdateMask() *fieldmaskpb.FieldMask {
|
|||
type DeleteShortcutRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Required. The resource name of the shortcut to delete.
|
||||
// Format: users/{user}/shortcuts/{shortcut}
|
||||
// Format: users/{username}/shortcuts/{shortcut}
|
||||
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
|
|
@ -393,12 +393,12 @@ var File_api_v1_shortcut_service_proto protoreflect.FileDescriptor
|
|||
|
||||
const file_api_v1_shortcut_service_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x1dapi/v1/shortcut_service.proto\x12\fmemos.api.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a google/protobuf/field_mask.proto\"\xaf\x01\n" +
|
||||
"\x1dapi/v1/shortcut_service.proto\x12\fmemos.api.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x19google/api/resource.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a google/protobuf/field_mask.proto\"\xb3\x01\n" +
|
||||
"\bShortcut\x12\x17\n" +
|
||||
"\x04name\x18\x01 \x01(\tB\x03\xe0A\bR\x04name\x12\x19\n" +
|
||||
"\x05title\x18\x02 \x01(\tB\x03\xe0A\x02R\x05title\x12\x1b\n" +
|
||||
"\x06filter\x18\x03 \x01(\tB\x03\xe0A\x01R\x06filter:R\xeaAO\n" +
|
||||
"\x15memos.api.v1/Shortcut\x12!users/{user}/shortcuts/{shortcut}*\tshortcuts2\bshortcut\"M\n" +
|
||||
"\x06filter\x18\x03 \x01(\tB\x03\xe0A\x01R\x06filter:V\xeaAS\n" +
|
||||
"\x15memos.api.v1/Shortcut\x12%users/{username}/shortcuts/{shortcut}*\tshortcuts2\bshortcut\"M\n" +
|
||||
"\x14ListShortcutsRequest\x125\n" +
|
||||
"\x06parent\x18\x01 \x01(\tB\x1d\xe0A\x02\xfaA\x17\x12\x15memos.api.v1/ShortcutR\x06parent\"M\n" +
|
||||
"\x15ListShortcutsResponse\x124\n" +
|
||||
|
|
|
|||
|
|
@ -506,11 +506,7 @@ func (x *ListUsersResponse) GetTotalSize() int32 {
|
|||
type GetUserRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Required. The resource name of the user.
|
||||
// Supports both numeric IDs and username strings:
|
||||
// - users/{id} (e.g., users/101)
|
||||
// - users/{username} (e.g., users/steven)
|
||||
//
|
||||
// Format: users/{id_or_username}
|
||||
// Format: users/{username}
|
||||
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
|
||||
// Optional. The fields to return in the response.
|
||||
// If not specified, all fields are returned.
|
||||
|
|
@ -979,8 +975,8 @@ func (x *ListAllUserStatsResponse) GetStats() []*UserStats {
|
|||
type UserSetting struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// The name of the user setting.
|
||||
// Format: users/{user}/settings/{setting}, {setting} is the key for the setting.
|
||||
// For example, "users/123/settings/GENERAL" for general settings.
|
||||
// Format: users/{username}/settings/{setting}, {setting} is the key for the setting.
|
||||
// For example, "users/steven/settings/GENERAL" for general settings.
|
||||
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
|
||||
// Types that are valid to be assigned to Value:
|
||||
//
|
||||
|
|
@ -2658,7 +2654,7 @@ const file_api_v1_user_service_proto_rawDesc = "" +
|
|||
"\x11memos.api.v1/UserR\x04name\"\x19\n" +
|
||||
"\x17ListAllUserStatsRequest\"I\n" +
|
||||
"\x18ListAllUserStatsResponse\x12-\n" +
|
||||
"\x05stats\x18\x01 \x03(\v2\x17.memos.api.v1.UserStatsR\x05stats\"\xb0\x04\n" +
|
||||
"\x05stats\x18\x01 \x03(\v2\x17.memos.api.v1.UserStatsR\x05stats\"\xb4\x04\n" +
|
||||
"\vUserSetting\x12\x17\n" +
|
||||
"\x04name\x18\x01 \x01(\tB\x03\xe0A\bR\x04name\x12S\n" +
|
||||
"\x0fgeneral_setting\x18\x02 \x01(\v2(.memos.api.v1.UserSetting.GeneralSettingH\x00R\x0egeneralSetting\x12V\n" +
|
||||
|
|
@ -2672,8 +2668,8 @@ const file_api_v1_user_service_proto_rawDesc = "" +
|
|||
"\x03Key\x12\x13\n" +
|
||||
"\x0fKEY_UNSPECIFIED\x10\x00\x12\v\n" +
|
||||
"\aGENERAL\x10\x01\x12\f\n" +
|
||||
"\bWEBHOOKS\x10\x04:Y\xeaAV\n" +
|
||||
"\x18memos.api.v1/UserSetting\x12\x1fusers/{user}/settings/{setting}*\fuserSettings2\vuserSettingB\a\n" +
|
||||
"\bWEBHOOKS\x10\x04:]\xeaAZ\n" +
|
||||
"\x18memos.api.v1/UserSetting\x12#users/{username}/settings/{setting}*\fuserSettings2\vuserSettingB\a\n" +
|
||||
"\x05value\"M\n" +
|
||||
"\x15GetUserSettingRequest\x124\n" +
|
||||
"\x04name\x18\x01 \x01(\tB \xe0A\x02\xfaA\x1a\n" +
|
||||
|
|
|
|||
|
|
@ -48,10 +48,8 @@ const (
|
|||
type UserServiceClient interface {
|
||||
// ListUsers returns a list of users.
|
||||
ListUsers(ctx context.Context, in *ListUsersRequest, opts ...grpc.CallOption) (*ListUsersResponse, error)
|
||||
// GetUser gets a user by ID or username.
|
||||
// Supports both numeric IDs and username strings:
|
||||
// - users/{id} (e.g., users/101)
|
||||
// - users/{username} (e.g., users/steven)
|
||||
// GetUser gets a user by username.
|
||||
// Format: users/{username} (e.g., users/steven)
|
||||
GetUser(ctx context.Context, in *GetUserRequest, opts ...grpc.CallOption) (*User, error)
|
||||
// CreateUser creates a new user.
|
||||
CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*User, error)
|
||||
|
|
@ -307,10 +305,8 @@ func (c *userServiceClient) DeleteUserNotification(ctx context.Context, in *Dele
|
|||
type UserServiceServer interface {
|
||||
// ListUsers returns a list of users.
|
||||
ListUsers(context.Context, *ListUsersRequest) (*ListUsersResponse, error)
|
||||
// GetUser gets a user by ID or username.
|
||||
// Supports both numeric IDs and username strings:
|
||||
// - users/{id} (e.g., users/101)
|
||||
// - users/{username} (e.g., users/steven)
|
||||
// GetUser gets a user by username.
|
||||
// Format: users/{username} (e.g., users/steven)
|
||||
GetUser(context.Context, *GetUserRequest) (*User, error)
|
||||
// CreateUser creates a new user.
|
||||
CreateUser(context.Context, *CreateUserRequest) (*User, error)
|
||||
|
|
|
|||
|
|
@ -1206,10 +1206,8 @@ paths:
|
|||
tags:
|
||||
- UserService
|
||||
description: |-
|
||||
GetUser gets a user by ID or username.
|
||||
Supports both numeric IDs and username strings:
|
||||
- users/{id} (e.g., users/101)
|
||||
- users/{username} (e.g., users/steven)
|
||||
GetUser gets a user by username.
|
||||
Format: users/{username} (e.g., users/steven)
|
||||
operationId: UserService_GetUser
|
||||
parameters:
|
||||
- name: user
|
||||
|
|
@ -2398,7 +2396,12 @@ components:
|
|||
backgroundColor:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/Color'
|
||||
description: Background color for the tag label.
|
||||
description: |-
|
||||
Optional background color for the tag label.
|
||||
When unset, the default tag color is used.
|
||||
blurContent:
|
||||
type: boolean
|
||||
description: Whether memos with this tag should have their content blurred.
|
||||
description: Metadata for a tag.
|
||||
InstanceSetting_TagsSetting:
|
||||
type: object
|
||||
|
|
@ -2931,7 +2934,7 @@ components:
|
|||
type: string
|
||||
description: |-
|
||||
The resource name of the shortcut.
|
||||
Format: users/{user}/shortcuts/{shortcut}
|
||||
Format: users/{username}/shortcuts/{shortcut}
|
||||
title:
|
||||
type: string
|
||||
description: The title of the shortcut.
|
||||
|
|
@ -3170,8 +3173,8 @@ components:
|
|||
type: string
|
||||
description: |-
|
||||
The name of the user setting.
|
||||
Format: users/{user}/settings/{setting}, {setting} is the key for the setting.
|
||||
For example, "users/123/settings/GENERAL" for general settings.
|
||||
Format: users/{username}/settings/{setting}, {setting} is the key for the setting.
|
||||
For example, "users/steven/settings/GENERAL" for general settings.
|
||||
generalSetting:
|
||||
$ref: '#/components/schemas/UserSetting_GeneralSetting'
|
||||
webhooksSetting:
|
||||
|
|
|
|||
|
|
@ -754,7 +754,8 @@ func (x *InstanceMemoRelatedSetting) GetReactions() []string {
|
|||
|
||||
type InstanceTagMetadata struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Background color for the tag label.
|
||||
// Optional background color for the tag label.
|
||||
// When unset, the default tag color is used.
|
||||
BackgroundColor *color.Color `protobuf:"bytes,1,opt,name=background_color,json=backgroundColor,proto3" json:"background_color,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
|
|
|
|||
|
|
@ -111,7 +111,8 @@ message InstanceMemoRelatedSetting {
|
|||
}
|
||||
|
||||
message InstanceTagMetadata {
|
||||
// Background color for the tag label.
|
||||
// Optional background color for the tag label.
|
||||
// When unset, the default tag color is used.
|
||||
google.type.Color background_color = 1;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -360,7 +360,7 @@ func SaveAttachmentBlob(ctx context.Context, profile *profile.Profile, stores *s
|
|||
}
|
||||
|
||||
if instanceStorageSetting.StorageType == storepb.InstanceStorageSetting_LOCAL {
|
||||
filepathTemplate := "assets/{timestamp}_{filename}"
|
||||
filepathTemplate := "assets/{timestamp}_{uuid}_{filename}"
|
||||
if instanceStorageSetting.FilepathTemplate != "" {
|
||||
filepathTemplate = instanceStorageSetting.FilepathTemplate
|
||||
}
|
||||
|
|
@ -377,6 +377,15 @@ func SaveAttachmentBlob(ctx context.Context, profile *profile.Profile, stores *s
|
|||
if !filepath.IsAbs(osPath) {
|
||||
osPath = filepath.Join(profile.Data, osPath)
|
||||
}
|
||||
osPath = ensureUniqueLocalAttachmentPath(osPath, create.UID)
|
||||
internalPath = filepath.ToSlash(osPath)
|
||||
if !filepath.IsAbs(filepath.FromSlash(internalPath)) {
|
||||
internalPath, err = filepath.Rel(profile.Data, osPath)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to get relative path")
|
||||
}
|
||||
internalPath = filepath.ToSlash(internalPath)
|
||||
}
|
||||
dir := filepath.Dir(osPath)
|
||||
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
|
||||
return errors.Wrap(err, "Failed to create directory")
|
||||
|
|
@ -514,6 +523,16 @@ func replaceFilenameWithPathTemplate(path, filename string) string {
|
|||
return path
|
||||
}
|
||||
|
||||
func ensureUniqueLocalAttachmentPath(path, uid string) string {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return path
|
||||
}
|
||||
|
||||
ext := filepath.Ext(path)
|
||||
base := strings.TrimSuffix(path, ext)
|
||||
return base + "_" + uid + ext
|
||||
}
|
||||
|
||||
func validateFilename(filename string) bool {
|
||||
// Reject path traversal attempts and make sure no additional directories are created
|
||||
if !filepath.IsLocal(filename) || strings.ContainsAny(filename, "/\\") {
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ func (s *APIV1Service) GetCurrentUser(ctx context.Context, _ *v1pb.GetCurrentUse
|
|||
}
|
||||
|
||||
return &v1pb.GetCurrentUserResponse{
|
||||
User: convertUserFromStore(user),
|
||||
User: convertUserFromStore(user, user),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -187,7 +187,7 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest)
|
|||
}
|
||||
|
||||
return &v1pb.SignInResponse{
|
||||
User: convertUserFromStore(existingUser),
|
||||
User: convertUserFromStore(existingUser, existingUser),
|
||||
AccessToken: accessToken,
|
||||
AccessTokenExpiresAt: timestamppb.New(accessExpiresAt),
|
||||
}, nil
|
||||
|
|
|
|||
|
|
@ -49,17 +49,8 @@ func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListId
|
|||
response := &v1pb.ListIdentityProvidersResponse{
|
||||
IdentityProviders: []*v1pb.IdentityProvider{},
|
||||
}
|
||||
|
||||
// Default to lowest-privilege role, update later based on real role
|
||||
currentUserRole := store.RoleUser
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err == nil && currentUser != nil {
|
||||
currentUserRole = currentUser.Role
|
||||
}
|
||||
|
||||
for _, identityProvider := range identityProviders {
|
||||
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
|
||||
response.IdentityProviders = append(response.IdentityProviders, redactIdentityProviderResponse(identityProviderConverted, currentUserRole))
|
||||
response.IdentityProviders = append(response.IdentityProviders, convertIdentityProviderFromStore(identityProvider))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
|
@ -79,15 +70,7 @@ func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.Ge
|
|||
return nil, status.Errorf(codes.NotFound, "identity provider not found")
|
||||
}
|
||||
|
||||
// Default to lowest-privilege role, update later based on real role
|
||||
currentUserRole := store.RoleUser
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err == nil && currentUser != nil {
|
||||
currentUserRole = currentUser.Role
|
||||
}
|
||||
|
||||
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
|
||||
return redactIdentityProviderResponse(identityProviderConverted, currentUserRole), nil
|
||||
return convertIdentityProviderFromStore(identityProvider), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
||||
|
|
@ -137,6 +120,15 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb
|
|||
}
|
||||
}
|
||||
|
||||
// Preserve write-only credential when the caller sends an empty value.
|
||||
if update.Config != nil {
|
||||
if oauth2Config := update.Config.GetOauth2Config(); oauth2Config != nil && oauth2Config.ClientSecret == "" {
|
||||
if existingOAuth := existing.Config.GetOauth2Config(); existingOAuth != nil {
|
||||
oauth2Config.ClientSecret = existingOAuth.ClientSecret
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
identityProvider, err := s.Store.UpdateIdentityProvider(ctx, update)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update identity provider, error: %+v", err)
|
||||
|
|
@ -188,12 +180,12 @@ func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider
|
|||
temp.Config = &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: oauth2Config.ClientId,
|
||||
ClientSecret: oauth2Config.ClientSecret,
|
||||
AuthUrl: oauth2Config.AuthUrl,
|
||||
TokenUrl: oauth2Config.TokenUrl,
|
||||
UserInfoUrl: oauth2Config.UserInfoUrl,
|
||||
Scopes: oauth2Config.Scopes,
|
||||
ClientId: oauth2Config.ClientId,
|
||||
// ClientSecret is write-only: never returned in responses.
|
||||
AuthUrl: oauth2Config.AuthUrl,
|
||||
TokenUrl: oauth2Config.TokenUrl,
|
||||
UserInfoUrl: oauth2Config.UserInfoUrl,
|
||||
Scopes: oauth2Config.Scopes,
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: oauth2Config.FieldMapping.Identifier,
|
||||
DisplayName: oauth2Config.FieldMapping.DisplayName,
|
||||
|
|
@ -241,13 +233,3 @@ func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProv
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func redactIdentityProviderResponse(identityProvider *v1pb.IdentityProvider, userRole store.Role) *v1pb.IdentityProvider {
|
||||
if userRole != store.RoleAdmin {
|
||||
if identityProvider.Type == v1pb.IdentityProvider_OAUTH2 {
|
||||
identityProvider.Config.GetOauth2Config().ClientSecret = ""
|
||||
}
|
||||
}
|
||||
|
||||
return identityProvider
|
||||
}
|
||||
|
|
|
|||
|
|
@ -71,8 +71,9 @@ func (s *APIV1Service) GetInstanceSetting(ctx context.Context, request *v1pb.Get
|
|||
return nil, status.Errorf(codes.NotFound, "instance setting not found")
|
||||
}
|
||||
|
||||
// For storage setting, only admin can get it.
|
||||
if instanceSetting.Key == storepb.InstanceSettingKey_STORAGE {
|
||||
// Storage and notification settings contain credentials; restrict to admins only.
|
||||
if instanceSetting.Key == storepb.InstanceSettingKey_STORAGE ||
|
||||
instanceSetting.Key == storepb.InstanceSettingKey_NOTIFICATION {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
|
|
@ -108,6 +109,28 @@ func (s *APIV1Service) UpdateInstanceSetting(ctx context.Context, request *v1pb.
|
|||
}
|
||||
|
||||
updateSetting := convertInstanceSettingToStore(request.Setting)
|
||||
|
||||
// Preserve write-only credential fields when the caller sends an empty value.
|
||||
// An empty string means "no change", not "clear the credential".
|
||||
switch updateSetting.Key {
|
||||
case storepb.InstanceSettingKey_NOTIFICATION:
|
||||
if notif := updateSetting.GetNotificationSetting(); notif != nil && notif.Email != nil && notif.Email.SmtpPassword == "" {
|
||||
existing, err := s.Store.GetInstanceNotificationSetting(ctx)
|
||||
if err == nil && existing != nil && existing.Email != nil {
|
||||
notif.Email.SmtpPassword = existing.Email.SmtpPassword
|
||||
}
|
||||
}
|
||||
case storepb.InstanceSettingKey_STORAGE:
|
||||
if storage := updateSetting.GetStorageSetting(); storage != nil && storage.S3Config != nil && storage.S3Config.AccessKeySecret == "" {
|
||||
existing, err := s.Store.GetInstanceStorageSetting(ctx)
|
||||
if err == nil && existing != nil && existing.S3Config != nil {
|
||||
storage.S3Config.AccessKeySecret = existing.S3Config.AccessKeySecret
|
||||
}
|
||||
}
|
||||
default:
|
||||
// No credential preservation needed for other setting types.
|
||||
}
|
||||
|
||||
instanceSetting, err := s.Store.UpsertInstanceSetting(ctx, updateSetting)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert instance setting: %v", err)
|
||||
|
|
@ -240,12 +263,12 @@ func convertInstanceStorageSettingFromStore(settingpb *storepb.InstanceStorageSe
|
|||
}
|
||||
if settingpb.S3Config != nil {
|
||||
setting.S3Config = &v1pb.InstanceSetting_StorageSetting_S3Config{
|
||||
AccessKeyId: settingpb.S3Config.AccessKeyId,
|
||||
AccessKeySecret: settingpb.S3Config.AccessKeySecret,
|
||||
Endpoint: settingpb.S3Config.Endpoint,
|
||||
Region: settingpb.S3Config.Region,
|
||||
Bucket: settingpb.S3Config.Bucket,
|
||||
UsePathStyle: settingpb.S3Config.UsePathStyle,
|
||||
AccessKeyId: settingpb.S3Config.AccessKeyId,
|
||||
// AccessKeySecret is write-only: never returned in responses.
|
||||
Endpoint: settingpb.S3Config.Endpoint,
|
||||
Region: settingpb.S3Config.Region,
|
||||
Bucket: settingpb.S3Config.Bucket,
|
||||
UsePathStyle: settingpb.S3Config.UsePathStyle,
|
||||
}
|
||||
}
|
||||
return setting
|
||||
|
|
@ -339,12 +362,12 @@ func convertInstanceNotificationSettingFromStore(setting *storepb.InstanceNotifi
|
|||
SmtpHost: setting.Email.SmtpHost,
|
||||
SmtpPort: setting.Email.SmtpPort,
|
||||
SmtpUsername: setting.Email.SmtpUsername,
|
||||
SmtpPassword: setting.Email.SmtpPassword,
|
||||
FromEmail: setting.Email.FromEmail,
|
||||
FromName: setting.Email.FromName,
|
||||
ReplyTo: setting.Email.ReplyTo,
|
||||
UseTls: setting.Email.UseTls,
|
||||
UseSsl: setting.Email.UseSsl,
|
||||
// SmtpPassword is write-only: never returned in responses.
|
||||
FromEmail: setting.Email.FromEmail,
|
||||
FromName: setting.Email.FromName,
|
||||
ReplyTo: setting.Email.ReplyTo,
|
||||
UseTls: setting.Email.UseTls,
|
||||
UseSsl: setting.Email.UseSsl,
|
||||
}
|
||||
}
|
||||
return notificationSetting
|
||||
|
|
@ -398,11 +421,10 @@ func validateInstanceTagsSetting(setting *v1pb.InstanceSetting_TagsSetting) erro
|
|||
if metadata == nil {
|
||||
return errors.Errorf("tag metadata is required for %q", tag)
|
||||
}
|
||||
if metadata.GetBackgroundColor() == nil {
|
||||
return errors.Errorf("background_color is required for %q", tag)
|
||||
}
|
||||
if err := validateInstanceColor(metadata.GetBackgroundColor()); err != nil {
|
||||
return errors.Wrapf(err, "background_color for %q", tag)
|
||||
if metadata.GetBackgroundColor() != nil {
|
||||
if err := validateInstanceColor(metadata.GetBackgroundColor()); err != nil {
|
||||
return errors.Wrapf(err, "background_color for %q", tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
@ -448,5 +470,6 @@ func (s *APIV1Service) GetInstanceAdmin(ctx context.Context) (*v1pb.User, error)
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
return convertUserFromStore(user), nil
|
||||
currentUser, _ := s.fetchCurrentUser(ctx)
|
||||
return convertUserFromStore(user, currentUser), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,20 +35,36 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set
|
|||
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if err := s.setMemoAttachmentsInternal(ctx, memo, request.Attachments); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.touchMemoUpdatedTimestamp(ctx, memo.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
updatedMemo, parentMemo, memoMessage, err := s.buildUpdatedMemoState(ctx, memo.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to build updated memo state")
|
||||
}
|
||||
s.dispatchMemoUpdatedSideEffects(ctx, updatedMemo, parentMemo, memoMessage)
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) setMemoAttachmentsInternal(ctx context.Context, memo *store.Memo, requestAttachments []*v1pb.Attachment) error {
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
return status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
|
||||
// Delete attachments that are not in the request.
|
||||
for _, attachment := range attachments {
|
||||
found := false
|
||||
for _, requestAttachment := range request.Attachments {
|
||||
for _, requestAttachment := range requestAttachments {
|
||||
requestAttachmentUID, err := ExtractAttachmentUIDFromName(requestAttachment.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
|
||||
return status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
|
||||
}
|
||||
if attachment.UID == requestAttachmentUID {
|
||||
found = true
|
||||
|
|
@ -60,24 +76,24 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set
|
|||
ID: int32(attachment.ID),
|
||||
MemoID: &memo.ID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete attachment")
|
||||
return status.Errorf(codes.Internal, "failed to delete attachment")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slices.Reverse(request.Attachments)
|
||||
slices.Reverse(requestAttachments)
|
||||
// Update attachments' memo_id in the request.
|
||||
for index, attachment := range request.Attachments {
|
||||
for index, attachment := range requestAttachments {
|
||||
attachmentUID, err := ExtractAttachmentUIDFromName(attachment.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
|
||||
return status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err)
|
||||
}
|
||||
tempAttachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err)
|
||||
return status.Errorf(codes.Internal, "failed to get attachment: %v", err)
|
||||
}
|
||||
if tempAttachment == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "attachment not found: %s", attachmentUID)
|
||||
return status.Errorf(codes.NotFound, "attachment not found: %s", attachmentUID)
|
||||
}
|
||||
updatedTs := time.Now().Unix() + int64(index)
|
||||
if err := s.Store.UpdateAttachment(ctx, &store.UpdateAttachment{
|
||||
|
|
@ -85,11 +101,11 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set
|
|||
MemoID: &memo.ID,
|
||||
UpdatedTs: &updatedTs,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err)
|
||||
return status.Errorf(codes.Internal, "failed to update attachment: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemoAttachments(ctx context.Context, request *v1pb.ListMemoAttachmentsRequest) (*v1pb.ListMemoAttachmentsResponse, error) {
|
||||
|
|
|
|||
|
|
@ -35,18 +35,34 @@ func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMe
|
|||
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if err := s.setMemoRelationsInternal(ctx, memo, request.Relations); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.touchMemoUpdatedTimestamp(ctx, memo.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
updatedMemo, parentMemo, memoMessage, err := s.buildUpdatedMemoState(ctx, memo.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to build updated memo state")
|
||||
}
|
||||
s.dispatchMemoUpdatedSideEffects(ctx, updatedMemo, parentMemo, memoMessage)
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) setMemoRelationsInternal(ctx context.Context, memo *store.Memo, relations []*v1pb.MemoRelation) error {
|
||||
referenceType := store.MemoRelationReference
|
||||
// Delete all reference relations first.
|
||||
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||
MemoID: &memo.ID,
|
||||
Type: &referenceType,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo relation")
|
||||
return status.Errorf(codes.Internal, "failed to delete memo relation")
|
||||
}
|
||||
|
||||
for _, relation := range request.Relations {
|
||||
for _, relation := range relations {
|
||||
// Ignore reflexive relations.
|
||||
if request.Name == relation.RelatedMemo.Name {
|
||||
if buildMemoName(memo.UID) == relation.RelatedMemo.Name {
|
||||
continue
|
||||
}
|
||||
// Ignore comment relations as there's no need to update a comment's relation.
|
||||
|
|
@ -56,22 +72,22 @@ func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMe
|
|||
}
|
||||
relatedMemoUID, err := ExtractMemoUIDFromName(relation.RelatedMemo.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid related memo name: %v", err)
|
||||
return status.Errorf(codes.InvalidArgument, "invalid related memo name: %v", err)
|
||||
}
|
||||
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &relatedMemoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get related memo")
|
||||
return status.Errorf(codes.Internal, "failed to get related memo")
|
||||
}
|
||||
if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memo.ID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: convertMemoRelationTypeToStore(relation.Type),
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert memo relation")
|
||||
return status.Errorf(codes.Internal, "failed to upsert memo relation")
|
||||
}
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemoRelations(ctx context.Context, request *v1pb.ListMemoRelationsRequest) (*v1pb.ListMemoRelationsResponse, error) {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,19 @@ import (
|
|||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// suppressSSEKey is a context key used to suppress the SSE broadcast from
|
||||
// CreateMemo when it is called internally (e.g., from CreateMemoComment).
|
||||
type suppressSSEKey struct{}
|
||||
|
||||
func withSuppressSSE(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, suppressSSEKey{}, true)
|
||||
}
|
||||
|
||||
func isSSESuppressed(ctx context.Context) bool {
|
||||
v, ok := ctx.Value(suppressSSEKey{}).(bool)
|
||||
return ok && v
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) {
|
||||
user, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -136,11 +149,15 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
|
|||
slog.Warn("Failed to dispatch memo created webhook", slog.Any("err", err))
|
||||
}
|
||||
|
||||
// Broadcast live refresh event.
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoCreated,
|
||||
Name: memoMessage.Name,
|
||||
})
|
||||
// Broadcast live refresh event (skipped when called from CreateMemoComment).
|
||||
if !isSSESuppressed(ctx) {
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoCreated,
|
||||
Name: memoMessage.Name,
|
||||
Visibility: memo.Visibility,
|
||||
CreatorID: resolveSSECreatorID(memo, nil),
|
||||
})
|
||||
}
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
|
@ -278,6 +295,14 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to batch load memo relations")
|
||||
}
|
||||
creatorIDs := make([]int32, 0, len(memos))
|
||||
for _, memo := range memos {
|
||||
creatorIDs = append(creatorIDs, memo.CreatorID)
|
||||
}
|
||||
creatorMap, err := s.listUsersByID(ctx, creatorIDs)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memo creators: %v", err)
|
||||
}
|
||||
|
||||
for _, memo := range memos {
|
||||
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
|
||||
|
|
@ -285,7 +310,7 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
|
|||
attachments := attachmentMap[memo.ID]
|
||||
relations := relationMap[memo.ID]
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, relations)
|
||||
memoMessage, err := s.convertMemoFromStoreWithCreators(ctx, memo, reactions, attachments, relations, creatorMap)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
|
@ -444,19 +469,11 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
|
|||
payload.Location = convertLocationToStore(request.Memo.Location)
|
||||
update.Payload = payload
|
||||
} else if path == "attachments" {
|
||||
_, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
|
||||
Name: request.Memo.Name,
|
||||
Attachments: request.Memo.Attachments,
|
||||
})
|
||||
if err != nil {
|
||||
if err := s.setMemoAttachmentsInternal(ctx, memo, request.Memo.Attachments); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to set memo attachments")
|
||||
}
|
||||
} else if path == "relations" {
|
||||
_, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
|
||||
Name: request.Memo.Name,
|
||||
Relations: request.Memo.Relations,
|
||||
})
|
||||
if err != nil {
|
||||
if err := s.setMemoRelationsInternal(ctx, memo, request.Memo.Relations); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to set memo relations")
|
||||
}
|
||||
}
|
||||
|
|
@ -472,37 +489,11 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
|
|||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get memo")
|
||||
}
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &request.Memo.Name,
|
||||
})
|
||||
memo, parentMemo, memoMessage, err := s.buildUpdatedMemoState(ctx, memo.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||
return nil, errors.Wrap(err, "failed to build updated memo state")
|
||||
}
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list attachments")
|
||||
}
|
||||
|
||||
relations, err := s.loadMemoRelations(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load memo relations")
|
||||
}
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, relations)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
// Try to dispatch webhook when memo is updated.
|
||||
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err))
|
||||
}
|
||||
|
||||
// Broadcast live refresh event.
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoUpdated,
|
||||
Name: memoMessage.Name,
|
||||
})
|
||||
s.dispatchMemoUpdatedSideEffects(ctx, memo, parentMemo, memoMessage)
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
|
@ -575,8 +566,10 @@ func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoR
|
|||
|
||||
// Broadcast live refresh event.
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoDeleted,
|
||||
Name: request.Name,
|
||||
Type: SSEEventMemoDeleted,
|
||||
Name: request.Name,
|
||||
Visibility: memo.Visibility,
|
||||
CreatorID: resolveSSECreatorID(memo, nil),
|
||||
})
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
|
|
@ -607,8 +600,9 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea
|
|||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
// Create the memo comment first.
|
||||
memoComment, err := s.CreateMemo(ctx, &v1pb.CreateMemoRequest{
|
||||
// Create the memo comment first; suppress the generic memo.created SSE event
|
||||
// since CreateMemoComment broadcasts memo.comment.created for the parent instead.
|
||||
memoComment, err := s.CreateMemo(withSuppressSSE(ctx), &v1pb.CreateMemoRequest{
|
||||
Memo: request.Comment,
|
||||
MemoId: request.CommentId,
|
||||
})
|
||||
|
|
@ -633,10 +627,14 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create memo relation")
|
||||
}
|
||||
creatorID, err := ExtractUserIDFromName(memoComment.Creator)
|
||||
creator, err := ResolveUserByName(ctx, s.Store, memoComment.Creator)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo creator")
|
||||
}
|
||||
if creator == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo creator not found")
|
||||
}
|
||||
creatorID := creator.ID
|
||||
if memoComment.Visibility != v1pb.Visibility_PRIVATE && creatorID != relatedMemo.CreatorID {
|
||||
if _, err := s.Store.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: creatorID,
|
||||
|
|
@ -662,8 +660,10 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea
|
|||
|
||||
// Broadcast live refresh event for the parent memo so subscribers see the new comment.
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoCommentCreated,
|
||||
Name: request.Name,
|
||||
Type: SSEEventMemoCommentCreated,
|
||||
Name: request.Name,
|
||||
Visibility: relatedMemo.Visibility,
|
||||
CreatorID: relatedMemo.CreatorID,
|
||||
})
|
||||
|
||||
return memoComment, nil
|
||||
|
|
@ -749,6 +749,14 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to batch load memo relations")
|
||||
}
|
||||
creatorIDs := make([]int32, 0, len(memos))
|
||||
for _, memo := range memos {
|
||||
creatorIDs = append(creatorIDs, memo.CreatorID)
|
||||
}
|
||||
creatorMap, err := s.listUsersByID(ctx, creatorIDs)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memo creators: %v", err)
|
||||
}
|
||||
|
||||
var memosResponse []*v1pb.Memo
|
||||
for _, m := range memos {
|
||||
|
|
@ -757,7 +765,7 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
|
|||
attachments := attachmentMap[m.ID]
|
||||
relations := relationMap[m.ID]
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, m, reactions, attachments, relations)
|
||||
memoMessage, err := s.convertMemoFromStoreWithCreators(ctx, m, reactions, attachments, relations, creatorMap)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
|
@ -812,10 +820,14 @@ func (s *APIV1Service) DispatchMemoCommentCreatedWebhook(ctx context.Context, co
|
|||
}
|
||||
|
||||
func (s *APIV1Service) dispatchMemoRelatedWebhook(ctx context.Context, memo *v1pb.Memo, activityType string) error {
|
||||
creatorID, err := ExtractUserIDFromName(memo.Creator)
|
||||
creator, err := ResolveUserByName(ctx, s.Store, memo.Creator)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.InvalidArgument, "invalid memo creator")
|
||||
}
|
||||
if creator == nil {
|
||||
return status.Errorf(codes.NotFound, "memo creator not found")
|
||||
}
|
||||
creatorID := creator.ID
|
||||
webhooks, err := s.Store.GetUserWebhooks(ctx, creatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -835,12 +847,8 @@ func (s *APIV1Service) dispatchMemoRelatedWebhook(ctx context.Context, memo *v1p
|
|||
}
|
||||
|
||||
func convertMemoToWebhookPayload(memo *v1pb.Memo) (*webhook.WebhookRequestPayload, error) {
|
||||
creatorID, err := ExtractUserIDFromName(memo.Creator)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid memo creator")
|
||||
}
|
||||
return &webhook.WebhookRequestPayload{
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, creatorID),
|
||||
Creator: memo.Creator,
|
||||
Memo: memo,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,14 @@ import (
|
|||
)
|
||||
|
||||
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment, relations []*v1pb.MemoRelation) (*v1pb.Memo, error) {
|
||||
creatorMap, err := s.listUsersByID(ctx, []int32{memo.CreatorID})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list memo creators")
|
||||
}
|
||||
return s.convertMemoFromStoreWithCreators(ctx, memo, reactions, attachments, relations, creatorMap)
|
||||
}
|
||||
|
||||
func (s *APIV1Service) convertMemoFromStoreWithCreators(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment, relations []*v1pb.MemoRelation, creatorMap map[int32]*store.User) (*v1pb.Memo, error) {
|
||||
displayTs := memo.CreatedTs
|
||||
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -24,10 +32,14 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem
|
|||
}
|
||||
|
||||
name := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
|
||||
creator := creatorMap[memo.CreatorID]
|
||||
if creator == nil {
|
||||
return nil, errors.New("memo creator not found")
|
||||
}
|
||||
memoMessage := &v1pb.Memo{
|
||||
Name: name,
|
||||
State: convertStateFromStore(memo.RowStatus),
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, memo.CreatorID),
|
||||
Creator: BuildUserName(creator.Username),
|
||||
CreateTime: timestamppb.New(time.Unix(memo.CreatedTs, 0)),
|
||||
UpdateTime: timestamppb.New(time.Unix(memo.UpdatedTs, 0)),
|
||||
DisplayTime: timestamppb.New(time.Unix(displayTs, 0)),
|
||||
|
|
@ -48,7 +60,10 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem
|
|||
|
||||
memoMessage.Reactions = []*v1pb.Reaction{}
|
||||
for _, reaction := range reactions {
|
||||
reactionResponse := convertReactionFromStore(reaction)
|
||||
reactionResponse, err := s.convertReactionFromStore(ctx, reaction)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert reaction")
|
||||
}
|
||||
memoMessage.Reactions = append(memoMessage.Reactions, reactionResponse)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,78 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) touchMemoUpdatedTimestamp(ctx context.Context, memoID int32) error {
|
||||
updatedTs := time.Now().Unix()
|
||||
if err := s.Store.UpdateMemo(ctx, &store.UpdateMemo{
|
||||
ID: memoID,
|
||||
UpdatedTs: &updatedTs,
|
||||
}); err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to update memo timestamp")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) buildUpdatedMemoState(ctx context.Context, memoID int32) (*store.Memo, *store.Memo, *v1pb.Memo, error) {
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoID})
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "failed to get memo")
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, nil, nil, errors.New("memo not found")
|
||||
}
|
||||
|
||||
memoName := buildMemoName(memo.UID)
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &memoName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "failed to list reactions")
|
||||
}
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "failed to list attachments")
|
||||
}
|
||||
relations, err := s.loadMemoRelations(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "failed to load memo relations")
|
||||
}
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, relations)
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
|
||||
var parentMemo *store.Memo
|
||||
if memo.ParentUID != nil {
|
||||
parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID})
|
||||
}
|
||||
|
||||
return memo, parentMemo, memoMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) dispatchMemoUpdatedSideEffects(ctx context.Context, memo *store.Memo, parentMemo *store.Memo, memoMessage *v1pb.Memo) {
|
||||
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err))
|
||||
}
|
||||
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoUpdated,
|
||||
Name: memoMessage.Name,
|
||||
Parent: memoMessage.GetParent(),
|
||||
Visibility: memo.Visibility,
|
||||
CreatorID: resolveSSECreatorID(memo, parentMemo),
|
||||
})
|
||||
}
|
||||
|
|
@ -53,7 +53,10 @@ func (s *APIV1Service) ListMemoReactions(ctx context.Context, request *v1pb.List
|
|||
Reactions: []*v1pb.Reaction{},
|
||||
}
|
||||
for _, reaction := range reactions {
|
||||
reactionMessage := convertReactionFromStore(reaction)
|
||||
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
|
||||
}
|
||||
response.Reactions = append(response.Reactions, reactionMessage)
|
||||
}
|
||||
return response, nil
|
||||
|
|
@ -95,13 +98,17 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups
|
|||
return nil, status.Errorf(codes.Internal, "failed to upsert reaction")
|
||||
}
|
||||
|
||||
reactionMessage := convertReactionFromStore(reaction)
|
||||
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
|
||||
}
|
||||
|
||||
// Broadcast live refresh event (reaction belongs to a memo).
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventReactionUpserted,
|
||||
Name: request.Reaction.ContentId,
|
||||
})
|
||||
var parentMemo *store.Memo
|
||||
if memo.ParentUID != nil {
|
||||
parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID})
|
||||
}
|
||||
s.SSEHub.Broadcast(buildMemoReactionSSEEvent(SSEEventReactionUpserted, request.Reaction.ContentId, memo, parentMemo))
|
||||
|
||||
return reactionMessage, nil
|
||||
}
|
||||
|
|
@ -142,24 +149,41 @@ func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.Del
|
|||
return nil, status.Errorf(codes.Internal, "failed to delete reaction")
|
||||
}
|
||||
|
||||
memoUID, err := ExtractMemoUIDFromName(reaction.ContentID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
|
||||
// Broadcast live refresh event (reaction belongs to a memo).
|
||||
s.SSEHub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventReactionDeleted,
|
||||
Name: reaction.ContentID,
|
||||
})
|
||||
var parentMemo *store.Memo
|
||||
if memo != nil && memo.ParentUID != nil {
|
||||
parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID})
|
||||
}
|
||||
s.SSEHub.Broadcast(buildMemoReactionSSEEvent(SSEEventReactionDeleted, reaction.ContentID, memo, parentMemo))
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func convertReactionFromStore(reaction *store.Reaction) *v1pb.Reaction {
|
||||
func (s *APIV1Service) convertReactionFromStore(ctx context.Context, reaction *store.Reaction) (*v1pb.Reaction, error) {
|
||||
reactionUID := fmt.Sprintf("%d", reaction.ID)
|
||||
creator, err := s.Store.GetUser(ctx, &store.FindUser{ID: &reaction.CreatorID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get reaction creator")
|
||||
}
|
||||
if creator == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "reaction creator not found")
|
||||
}
|
||||
// Generate nested resource name: memos/{memo}/reactions/{reaction}
|
||||
// reaction.ContentID already contains "memos/{memo}"
|
||||
return &v1pb.Reaction{
|
||||
Name: fmt.Sprintf("%s/%s%s", reaction.ContentID, ReactionNamePrefix, reactionUID),
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, reaction.CreatorID),
|
||||
Creator: BuildUserName(creator.Username),
|
||||
ContentId: reaction.ContentID,
|
||||
ReactionType: reaction.ReactionType,
|
||||
CreateTime: timestamppb.New(time.Unix(reaction.CreatedTs, 0)),
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -77,17 +77,6 @@ func ExtractUserIDFromName(name string) (int32, error) {
|
|||
return id, nil
|
||||
}
|
||||
|
||||
// extractUserIdentifierFromName extracts the identifier (ID or username) from a user resource name.
|
||||
// Supports: "users/101" or "users/steven"
|
||||
// Returns the identifier string (e.g., "101" or "steven").
|
||||
func extractUserIdentifierFromName(name string) string {
|
||||
tokens, err := GetNameParentTokens(name, UserNamePrefix)
|
||||
if err != nil || len(tokens) == 0 {
|
||||
return ""
|
||||
}
|
||||
return tokens[0]
|
||||
}
|
||||
|
||||
// ExtractMemoUIDFromName returns the memo UID from a resource name.
|
||||
// e.g., "memos/uuid" -> "uuid".
|
||||
func ExtractMemoUIDFromName(name string) (string, error) {
|
||||
|
|
|
|||
|
|
@ -17,37 +17,44 @@ import (
|
|||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// Helper function to extract user ID and shortcut ID from shortcut resource name.
|
||||
// Helper function to extract user and shortcut ID from shortcut resource name.
|
||||
// Format: users/{user}/shortcuts/{shortcut}.
|
||||
func extractUserAndShortcutIDFromName(name string) (int32, string, error) {
|
||||
func (s *APIV1Service) extractUserAndShortcutIDFromName(ctx context.Context, name string) (*store.User, string, error) {
|
||||
parts := strings.Split(name, "/")
|
||||
if len(parts) != 4 || parts[0] != "users" || parts[2] != "shortcuts" {
|
||||
return 0, "", errors.Errorf("invalid shortcut name format: %s", name)
|
||||
return nil, "", errors.Errorf("invalid shortcut name format: %s", name)
|
||||
}
|
||||
|
||||
userID, err := util.ConvertStringToInt32(parts[1])
|
||||
user, err := ResolveUserByName(ctx, s.Store, BuildUserName(parts[1]))
|
||||
if err != nil {
|
||||
return 0, "", errors.Errorf("invalid user ID %q", parts[1])
|
||||
return nil, "", err
|
||||
}
|
||||
if user == nil {
|
||||
return nil, "", errors.Errorf("user not found: %s", parts[1])
|
||||
}
|
||||
|
||||
shortcutID := parts[3]
|
||||
if shortcutID == "" {
|
||||
return 0, "", errors.Errorf("empty shortcut ID in name: %s", name)
|
||||
return nil, "", errors.Errorf("empty shortcut ID in name: %s", name)
|
||||
}
|
||||
|
||||
return userID, shortcutID, nil
|
||||
return user, shortcutID, nil
|
||||
}
|
||||
|
||||
// Helper function to construct shortcut resource name.
|
||||
func constructShortcutName(userID int32, shortcutID string) string {
|
||||
return fmt.Sprintf("users/%d/shortcuts/%s", userID, shortcutID)
|
||||
func constructShortcutName(username string, shortcutID string) string {
|
||||
return fmt.Sprintf("%s/shortcuts/%s", BuildUserName(username), shortcutID)
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListShortcuts(ctx context.Context, request *v1pb.ListShortcutsRequest) (*v1pb.ListShortcutsResponse, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
user, err := ResolveUserByName(ctx, s.Store, request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -74,7 +81,7 @@ func (s *APIV1Service) ListShortcuts(ctx context.Context, request *v1pb.ListShor
|
|||
shortcuts := []*v1pb.Shortcut{}
|
||||
for _, shortcut := range shortcutsUserSetting.GetShortcuts() {
|
||||
shortcuts = append(shortcuts, &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, shortcut.GetId()),
|
||||
Name: constructShortcutName(user.Username, shortcut.GetId()),
|
||||
Title: shortcut.GetTitle(),
|
||||
Filter: shortcut.GetFilter(),
|
||||
})
|
||||
|
|
@ -86,10 +93,11 @@ func (s *APIV1Service) ListShortcuts(ctx context.Context, request *v1pb.ListShor
|
|||
}
|
||||
|
||||
func (s *APIV1Service) GetShortcut(ctx context.Context, request *v1pb.GetShortcutRequest) (*v1pb.Shortcut, error) {
|
||||
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Name)
|
||||
user, shortcutID, err := s.extractUserAndShortcutIDFromName(ctx, request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -114,7 +122,7 @@ func (s *APIV1Service) GetShortcut(ctx context.Context, request *v1pb.GetShortcu
|
|||
for _, shortcut := range shortcutsUserSetting.GetShortcuts() {
|
||||
if shortcut.GetId() == shortcutID {
|
||||
return &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, shortcut.GetId()),
|
||||
Name: constructShortcutName(user.Username, shortcut.GetId()),
|
||||
Title: shortcut.GetTitle(),
|
||||
Filter: shortcut.GetFilter(),
|
||||
}, nil
|
||||
|
|
@ -125,10 +133,14 @@ func (s *APIV1Service) GetShortcut(ctx context.Context, request *v1pb.GetShortcu
|
|||
}
|
||||
|
||||
func (s *APIV1Service) CreateShortcut(ctx context.Context, request *v1pb.CreateShortcutRequest) (*v1pb.Shortcut, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
user, err := ResolveUserByName(ctx, s.Store, request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -151,7 +163,7 @@ func (s *APIV1Service) CreateShortcut(ctx context.Context, request *v1pb.CreateS
|
|||
}
|
||||
if request.ValidateOnly {
|
||||
return &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, newShortcut.GetId()),
|
||||
Name: constructShortcutName(user.Username, newShortcut.GetId()),
|
||||
Title: newShortcut.GetTitle(),
|
||||
Filter: newShortcut.GetFilter(),
|
||||
}, nil
|
||||
|
|
@ -190,17 +202,18 @@ func (s *APIV1Service) CreateShortcut(ctx context.Context, request *v1pb.CreateS
|
|||
}
|
||||
|
||||
return &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, newShortcut.GetId()),
|
||||
Name: constructShortcutName(user.Username, newShortcut.GetId()),
|
||||
Title: newShortcut.GetTitle(),
|
||||
Filter: newShortcut.GetFilter(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateShortcut(ctx context.Context, request *v1pb.UpdateShortcutRequest) (*v1pb.Shortcut, error) {
|
||||
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Shortcut.Name)
|
||||
user, shortcutID, err := s.extractUserAndShortcutIDFromName(ctx, request.Shortcut.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -262,17 +275,18 @@ func (s *APIV1Service) UpdateShortcut(ctx context.Context, request *v1pb.UpdateS
|
|||
}
|
||||
|
||||
return &v1pb.Shortcut{
|
||||
Name: constructShortcutName(userID, foundShortcut.GetId()),
|
||||
Name: constructShortcutName(user.Username, foundShortcut.GetId()),
|
||||
Title: foundShortcut.GetTitle(),
|
||||
Filter: foundShortcut.GetFilter(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteShortcut(ctx context.Context, request *v1pb.DeleteShortcutRequest) (*emptypb.Empty, error) {
|
||||
userID, shortcutID, err := extractUserAndShortcutIDFromName(request.Name)
|
||||
user, shortcutID, err := s.extractUserAndShortcutIDFromName(ctx, request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid shortcut name: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
package v1
|
||||
|
||||
import "github.com/usememos/memos/store"
|
||||
|
||||
func buildMemoName(uid string) string {
|
||||
return MemoNamePrefix + uid
|
||||
}
|
||||
|
||||
// resolveSSECreatorID returns the CreatorID used for SSE delivery filtering.
|
||||
// For a comment memo, it returns the parent memo's CreatorID so that private
|
||||
// parent-memo events are scoped to the parent owner.
|
||||
func resolveSSECreatorID(memo *store.Memo, parentMemo *store.Memo) int32 {
|
||||
if memo == nil {
|
||||
return 0
|
||||
}
|
||||
if parentMemo != nil {
|
||||
return parentMemo.CreatorID
|
||||
}
|
||||
return memo.CreatorID
|
||||
}
|
||||
|
||||
// buildMemoReactionSSEEvent constructs an SSEEvent for a reaction on a memo.
|
||||
// Pass parentMemo when the memo is a comment (memo.ParentUID != nil).
|
||||
func buildMemoReactionSSEEvent(eventType SSEEventType, contentID string, memo *store.Memo, parentMemo *store.Memo) *SSEEvent {
|
||||
parent := ""
|
||||
if memo != nil && memo.ParentUID != nil {
|
||||
parent = buildMemoName(*memo.ParentUID)
|
||||
}
|
||||
visibility := store.Visibility("")
|
||||
if memo != nil {
|
||||
visibility = memo.Visibility
|
||||
}
|
||||
return &SSEEvent{
|
||||
Type: eventType,
|
||||
Name: contentID,
|
||||
Parent: parent,
|
||||
Visibility: visibility,
|
||||
CreatorID: resolveSSECreatorID(memo, parentMemo),
|
||||
}
|
||||
}
|
||||
|
|
@ -17,10 +17,14 @@ const (
|
|||
sseHeartbeatInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
// RegisterSSERoutes registers the SSE endpoint on the given Echo instance.
|
||||
func RegisterSSERoutes(echoServer *echo.Echo, hub *SSEHub, storeInstance *store.Store, secret string) {
|
||||
type sseRouteRegistrar interface {
|
||||
GET(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) echo.RouteInfo
|
||||
}
|
||||
|
||||
// RegisterSSERoutes registers the SSE endpoint on the given Echo router.
|
||||
func RegisterSSERoutes(router sseRouteRegistrar, hub *SSEHub, storeInstance *store.Store, secret string) {
|
||||
authenticator := auth.NewAuthenticator(storeInstance, secret)
|
||||
echoServer.GET("/api/v1/sse", func(c *echo.Context) error {
|
||||
router.GET("/api/v1/sse", func(c *echo.Context) error {
|
||||
return handleSSE(c, hub, authenticator)
|
||||
})
|
||||
}
|
||||
|
|
@ -34,6 +38,10 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator)
|
|||
if result == nil {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
userID, role := getSSEClientIdentity(result)
|
||||
if userID == 0 {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "authentication required"})
|
||||
}
|
||||
|
||||
// Set SSE headers.
|
||||
w := c.Response()
|
||||
|
|
@ -49,7 +57,7 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator)
|
|||
}
|
||||
|
||||
// Subscribe to the hub.
|
||||
client := hub.Subscribe()
|
||||
client := hub.Subscribe(userID, role)
|
||||
defer hub.Unsubscribe(client)
|
||||
|
||||
// Create a ticker for heartbeat pings.
|
||||
|
|
@ -58,13 +66,13 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator)
|
|||
|
||||
ctx := c.Request().Context()
|
||||
|
||||
slog.Debug("SSE client connected")
|
||||
slog.Debug("SSE client connected", "userID", userID)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Client disconnected.
|
||||
slog.Debug("SSE client disconnected")
|
||||
slog.Debug("SSE client disconnected", "userID", userID)
|
||||
return nil
|
||||
|
||||
case data, ok := <-client.events:
|
||||
|
|
@ -91,3 +99,16 @@ func handleSSE(c *echo.Context, hub *SSEHub, authenticator *auth.Authenticator)
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getSSEClientIdentity(result *auth.AuthResult) (int32, store.Role) {
|
||||
if result == nil {
|
||||
return 0, store.RoleUser
|
||||
}
|
||||
if result.Claims != nil {
|
||||
return result.Claims.UserID, store.Role(result.Claims.Role)
|
||||
}
|
||||
if result.User != nil {
|
||||
return result.User.ID, result.User.Role
|
||||
}
|
||||
return 0, store.RoleUser
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import (
|
|||
"encoding/json"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// SSEEventType represents the type of change event.
|
||||
|
|
@ -24,6 +26,11 @@ type SSEEvent struct {
|
|||
// Name is the affected resource name (e.g., "memos/xxxx").
|
||||
// For reaction events, this is the memo resource name that the reaction belongs to.
|
||||
Name string `json:"name"`
|
||||
// Parent is the parent memo resource name when the affected resource is a comment.
|
||||
Parent string `json:"parent,omitempty"`
|
||||
// Visibility and CreatorID are used only for server-side delivery filtering.
|
||||
Visibility store.Visibility `json:"-"`
|
||||
CreatorID int32 `json:"-"`
|
||||
}
|
||||
|
||||
// JSON returns the JSON representation of the event.
|
||||
|
|
@ -40,6 +47,8 @@ func (e *SSEEvent) JSON() []byte {
|
|||
// SSEClient represents a single SSE connection.
|
||||
type SSEClient struct {
|
||||
events chan []byte
|
||||
userID int32
|
||||
role store.Role
|
||||
}
|
||||
|
||||
// SSEHub manages SSE client connections and broadcasts events.
|
||||
|
|
@ -58,10 +67,12 @@ func NewSSEHub() *SSEHub {
|
|||
|
||||
// Subscribe registers a new client and returns it.
|
||||
// The caller must call Unsubscribe when done.
|
||||
func (h *SSEHub) Subscribe() *SSEClient {
|
||||
func (h *SSEHub) Subscribe(userID int32, role store.Role) *SSEClient {
|
||||
c := &SSEClient{
|
||||
// Buffer a few events so a slow client doesn't block broadcasting.
|
||||
events: make(chan []byte, 32),
|
||||
userID: userID,
|
||||
role: role,
|
||||
}
|
||||
h.mu.Lock()
|
||||
h.clients[c] = struct{}{}
|
||||
|
|
@ -90,6 +101,9 @@ func (h *SSEHub) Broadcast(event *SSEEvent) {
|
|||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
for c := range h.clients {
|
||||
if !c.canReceive(event) {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case c.events <- data:
|
||||
default:
|
||||
|
|
@ -97,3 +111,15 @@ func (h *SSEHub) Broadcast(event *SSEEvent) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SSEClient) canReceive(event *SSEEvent) bool {
|
||||
switch event.Visibility {
|
||||
case store.Private:
|
||||
return c.userID == event.CreatorID || c.role == store.RoleAdmin
|
||||
case store.Public, store.Protected, "":
|
||||
return true
|
||||
default:
|
||||
slog.Warn("SSE canReceive: unknown visibility type, denying event", "visibility", event.Visibility)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,12 +6,36 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// helpers shared by multiple tests in this file.
|
||||
|
||||
func mustReceive(t *testing.T, ch <-chan []byte, within time.Duration) []byte {
|
||||
t.Helper()
|
||||
select {
|
||||
case data := <-ch:
|
||||
return data
|
||||
case <-time.After(within):
|
||||
t.Fatal("timed out waiting for SSE event")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func mustNotReceive(t *testing.T, ch <-chan []byte, within time.Duration) {
|
||||
t.Helper()
|
||||
select {
|
||||
case data := <-ch:
|
||||
t.Fatalf("unexpected SSE event received: %s", data)
|
||||
case <-time.After(within):
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEHub_SubscribeUnsubscribe(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
|
||||
client := hub.Subscribe()
|
||||
client := hub.Subscribe(1, store.RoleUser)
|
||||
require.NotNil(t, client)
|
||||
require.NotNil(t, client.events)
|
||||
|
||||
|
|
@ -25,7 +49,7 @@ func TestSSEHub_SubscribeUnsubscribe(t *testing.T) {
|
|||
|
||||
func TestSSEHub_Broadcast(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
client := hub.Subscribe()
|
||||
client := hub.Subscribe(1, store.RoleUser)
|
||||
defer hub.Unsubscribe(client)
|
||||
|
||||
event := &SSEEvent{Type: SSEEventMemoCreated, Name: "memos/123"}
|
||||
|
|
@ -42,9 +66,9 @@ func TestSSEHub_Broadcast(t *testing.T) {
|
|||
|
||||
func TestSSEHub_BroadcastMultipleClients(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
c1 := hub.Subscribe()
|
||||
c1 := hub.Subscribe(1, store.RoleUser)
|
||||
defer hub.Unsubscribe(c1)
|
||||
c2 := hub.Subscribe()
|
||||
c2 := hub.Subscribe(2, store.RoleUser)
|
||||
defer hub.Unsubscribe(c2)
|
||||
|
||||
event := &SSEEvent{Type: SSEEventMemoDeleted, Name: "memos/456"}
|
||||
|
|
@ -62,9 +86,144 @@ func TestSSEHub_BroadcastMultipleClients(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSSEEvent_JSON(t *testing.T) {
|
||||
e := &SSEEvent{Type: SSEEventMemoUpdated, Name: "memos/789"}
|
||||
e := &SSEEvent{Type: SSEEventMemoUpdated, Name: "memos/789", Parent: "memos/123"}
|
||||
data := e.JSON()
|
||||
require.NotEmpty(t, data)
|
||||
assert.Contains(t, string(data), `"type":"memo.updated"`)
|
||||
assert.Contains(t, string(data), `"name":"memos/789"`)
|
||||
assert.Contains(t, string(data), `"parent":"memos/123"`)
|
||||
}
|
||||
|
||||
func TestSSEHub_PrivateEventsAreScoped(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
owner := hub.Subscribe(1, store.RoleUser)
|
||||
defer hub.Unsubscribe(owner)
|
||||
other := hub.Subscribe(2, store.RoleUser)
|
||||
defer hub.Unsubscribe(other)
|
||||
admin := hub.Subscribe(3, store.RoleAdmin)
|
||||
defer hub.Unsubscribe(admin)
|
||||
|
||||
hub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoUpdated,
|
||||
Name: "memos/private",
|
||||
Visibility: store.Private,
|
||||
CreatorID: 1,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-owner.events:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("owner should receive private event")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-admin.events:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("admin should receive private event")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-other.events:
|
||||
t.Fatal("non-owner should not receive private event")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEClient_CanReceive_UnknownVisibility(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
client := hub.Subscribe(1, store.RoleUser)
|
||||
defer hub.Unsubscribe(client)
|
||||
|
||||
// An event with an unrecognised visibility value should be denied (safe default).
|
||||
hub.Broadcast(&SSEEvent{
|
||||
Type: SSEEventMemoUpdated,
|
||||
Name: "memos/unknown-vis",
|
||||
Visibility: store.Visibility("CUSTOM"),
|
||||
})
|
||||
|
||||
mustNotReceive(t, client.events, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSSEHub_SlowClientEventsDropped(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
// Subscribe but never read, so the channel fills up.
|
||||
slow := hub.Subscribe(1, store.RoleUser)
|
||||
defer hub.Unsubscribe(slow)
|
||||
|
||||
event := &SSEEvent{Type: SSEEventMemoCreated, Name: "memos/x"}
|
||||
// Send more events than the buffer capacity (32).
|
||||
for range 40 {
|
||||
hub.Broadcast(event) // must not block
|
||||
}
|
||||
|
||||
// At most 32 events should have been queued; the rest were silently dropped.
|
||||
assert.LessOrEqual(t, len(slow.events), 32)
|
||||
}
|
||||
|
||||
func TestResolveSSECreatorID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
memo *store.Memo
|
||||
parentMemo *store.Memo
|
||||
want int32
|
||||
}{
|
||||
{
|
||||
name: "nil memo returns 0",
|
||||
memo: nil, parentMemo: nil,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "memo without parent returns memo CreatorID",
|
||||
memo: &store.Memo{CreatorID: 5},
|
||||
parentMemo: nil,
|
||||
want: 5,
|
||||
},
|
||||
{
|
||||
name: "memo with parent returns parent CreatorID",
|
||||
memo: &store.Memo{CreatorID: 5},
|
||||
parentMemo: &store.Memo{CreatorID: 9},
|
||||
want: 9,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, resolveSSECreatorID(tc.memo, tc.parentMemo))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMemoReactionSSEEvent(t *testing.T) {
|
||||
parentUID := "parent-uid"
|
||||
|
||||
t.Run("top-level memo reaction", func(t *testing.T) {
|
||||
memo := &store.Memo{CreatorID: 10, Visibility: store.Public}
|
||||
event := buildMemoReactionSSEEvent(SSEEventReactionUpserted, "memos/abc", memo, nil)
|
||||
assert.Equal(t, SSEEventReactionUpserted, event.Type)
|
||||
assert.Equal(t, "memos/abc", event.Name)
|
||||
assert.Equal(t, "", event.Parent)
|
||||
assert.Equal(t, store.Public, event.Visibility)
|
||||
assert.Equal(t, int32(10), event.CreatorID)
|
||||
})
|
||||
|
||||
t.Run("reaction on comment is scoped to parent owner", func(t *testing.T) {
|
||||
memo := &store.Memo{
|
||||
CreatorID: 10,
|
||||
Visibility: store.Private,
|
||||
ParentUID: &parentUID,
|
||||
}
|
||||
parentMemo := &store.Memo{CreatorID: 7}
|
||||
event := buildMemoReactionSSEEvent(SSEEventReactionDeleted, "memos/abc", memo, parentMemo)
|
||||
assert.Equal(t, SSEEventReactionDeleted, event.Type)
|
||||
assert.Equal(t, MemoNamePrefix+parentUID, event.Parent)
|
||||
assert.Equal(t, store.Private, event.Visibility)
|
||||
assert.Equal(t, int32(7), event.CreatorID)
|
||||
})
|
||||
|
||||
t.Run("nil memo produces a safe zero-value event", func(t *testing.T) {
|
||||
event := buildMemoReactionSSEEvent(SSEEventReactionUpserted, "memos/abc", nil, nil)
|
||||
assert.Equal(t, "memos/abc", event.Name)
|
||||
assert.Equal(t, "", event.Parent)
|
||||
assert.Equal(t, store.Visibility(""), event.Visibility)
|
||||
assert.Equal(t, int32(0), event.CreatorID)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,272 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
teststore "github.com/usememos/memos/store/test"
|
||||
)
|
||||
|
||||
// newIntegrationService builds a minimal APIV1Service backed by an in-memory
|
||||
// SQLite database. The store is closed automatically via t.Cleanup.
|
||||
func newIntegrationService(t *testing.T) *APIV1Service {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
st := teststore.NewTestingStore(ctx, t)
|
||||
t.Cleanup(func() { st.Close() })
|
||||
p := &profile.Profile{Demo: true, Data: t.TempDir(), Driver: "sqlite", DSN: ":memory:"}
|
||||
return NewAPIV1Service("test-secret", p, st)
|
||||
}
|
||||
|
||||
// userCtx returns a context that authenticates as the given user.
|
||||
func userCtx(ctx context.Context, userID int32) context.Context {
|
||||
return context.WithValue(ctx, auth.UserIDContextKey, userID)
|
||||
}
|
||||
|
||||
// collectEventsFor reads events from ch for the given duration and returns them.
|
||||
func collectEventsFor(ch <-chan []byte, d time.Duration) []string {
|
||||
var out []string
|
||||
deadline := time.After(d)
|
||||
for {
|
||||
select {
|
||||
case data := <-ch:
|
||||
out = append(out, string(data))
|
||||
case <-deadline:
|
||||
return out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- context suppression ----
|
||||
|
||||
func TestSuppressSSEContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("default context is not suppressed", func(t *testing.T) {
|
||||
assert.False(t, isSSESuppressed(ctx))
|
||||
})
|
||||
|
||||
t.Run("withSuppressSSE marks context as suppressed", func(t *testing.T) {
|
||||
assert.True(t, isSSESuppressed(withSuppressSSE(ctx)))
|
||||
})
|
||||
|
||||
t.Run("suppression does not bleed into parent context", func(t *testing.T) {
|
||||
suppressed := withSuppressSSE(ctx)
|
||||
_ = suppressed
|
||||
assert.False(t, isSSESuppressed(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// ---- CreateMemoComment double-broadcast fix ----
|
||||
|
||||
func TestCreateMemoComment_NoDuplicateSSEBroadcast(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := newIntegrationService(t)
|
||||
|
||||
// Create an admin so the store is initialised, then a regular commenter.
|
||||
author, err := svc.Store.CreateUser(ctx, &store.User{
|
||||
Username: "author", Role: store.RoleAdmin, Email: "author@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
commenter, err := svc.Store.CreateUser(ctx, &store.User{
|
||||
Username: "commenter", Role: store.RoleUser, Email: "commenter@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
authorCtx := userCtx(ctx, author.ID)
|
||||
commenterCtx := userCtx(ctx, commenter.ID)
|
||||
|
||||
// Create a public memo so the commenter can react.
|
||||
parent, err := svc.CreateMemo(authorCtx, &v1pb.CreateMemoRequest{
|
||||
Memo: &v1pb.Memo{Content: "parent memo", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Subscribe after the parent memo is created so the memo.created event
|
||||
// for the parent does not pollute the assertion window.
|
||||
client := svc.SSEHub.Subscribe(author.ID, store.RoleAdmin)
|
||||
defer svc.SSEHub.Unsubscribe(client)
|
||||
|
||||
// Create a comment. Before the fix, this fired both memo.created (for the
|
||||
// comment memo) and memo.comment.created (for the parent).
|
||||
_, err = svc.CreateMemoComment(commenterCtx, &v1pb.CreateMemoCommentRequest{
|
||||
Name: parent.Name,
|
||||
Comment: &v1pb.Memo{Content: "a comment", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Give the synchronous broadcast a moment to land in the buffer, then
|
||||
// collect everything that arrived.
|
||||
events := collectEventsFor(client.events, 150*time.Millisecond)
|
||||
|
||||
require.Len(t, events, 1, "expected exactly one SSE event for a comment creation, got: %v", events)
|
||||
assert.True(t, strings.Contains(events[0], `"memo.comment.created"`),
|
||||
"expected memo.comment.created, got: %s", events[0])
|
||||
}
|
||||
|
||||
// ---- Reaction SSE events carry correct visibility / parent fields ----
|
||||
|
||||
func TestUpsertMemoReaction_SSEEvent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := newIntegrationService(t)
|
||||
|
||||
user, err := svc.Store.CreateUser(ctx, &store.User{
|
||||
Username: "user", Role: store.RoleAdmin, Email: "user@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
uctx := userCtx(ctx, user.ID)
|
||||
|
||||
memo, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{
|
||||
Memo: &v1pb.Memo{Content: "reacted memo", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin)
|
||||
defer svc.SSEHub.Unsubscribe(client)
|
||||
|
||||
_, err = svc.UpsertMemoReaction(uctx, &v1pb.UpsertMemoReactionRequest{
|
||||
Name: memo.Name,
|
||||
Reaction: &v1pb.Reaction{
|
||||
ContentId: memo.Name,
|
||||
ReactionType: "👍",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
data := mustReceive(t, client.events, time.Second)
|
||||
payload := string(data)
|
||||
assert.Contains(t, payload, `"reaction.upserted"`)
|
||||
assert.Contains(t, payload, memo.Name)
|
||||
mustNotReceive(t, client.events, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestDeleteMemoReaction_SSEEvent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := newIntegrationService(t)
|
||||
|
||||
user, err := svc.Store.CreateUser(ctx, &store.User{
|
||||
Username: "user", Role: store.RoleAdmin, Email: "user@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
uctx := userCtx(ctx, user.ID)
|
||||
|
||||
memo, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{
|
||||
Memo: &v1pb.Memo{Content: "reacted memo", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
reaction, err := svc.UpsertMemoReaction(uctx, &v1pb.UpsertMemoReactionRequest{
|
||||
Name: memo.Name,
|
||||
Reaction: &v1pb.Reaction{
|
||||
ContentId: memo.Name,
|
||||
ReactionType: "❤️",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin)
|
||||
defer svc.SSEHub.Unsubscribe(client)
|
||||
|
||||
_, err = svc.DeleteMemoReaction(uctx, &v1pb.DeleteMemoReactionRequest{
|
||||
Name: reaction.Name,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
data := mustReceive(t, client.events, time.Second)
|
||||
payload := string(data)
|
||||
assert.Contains(t, payload, `"reaction.deleted"`)
|
||||
assert.Contains(t, payload, memo.Name)
|
||||
mustNotReceive(t, client.events, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSetMemoAttachments_EmitsMemoUpdatedSSEEvent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := newIntegrationService(t)
|
||||
|
||||
user, err := svc.Store.CreateUser(ctx, &store.User{
|
||||
Username: "user", Role: store.RoleAdmin, Email: "user@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
uctx := userCtx(ctx, user.ID)
|
||||
|
||||
memo, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{
|
||||
Memo: &v1pb.Memo{Content: "memo with attachments", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
attachment, err := svc.CreateAttachment(uctx, &v1pb.CreateAttachmentRequest{
|
||||
Attachment: &v1pb.Attachment{
|
||||
Filename: "test.txt",
|
||||
Size: 5,
|
||||
Type: "text/plain",
|
||||
Content: []byte("hello"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin)
|
||||
defer svc.SSEHub.Unsubscribe(client)
|
||||
|
||||
_, err = svc.SetMemoAttachments(uctx, &v1pb.SetMemoAttachmentsRequest{
|
||||
Name: memo.Name,
|
||||
Attachments: []*v1pb.Attachment{
|
||||
{Name: attachment.Name},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
data := mustReceive(t, client.events, time.Second)
|
||||
payload := string(data)
|
||||
assert.Contains(t, payload, `"memo.updated"`)
|
||||
assert.Contains(t, payload, memo.Name)
|
||||
mustNotReceive(t, client.events, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSetMemoRelations_EmitsMemoUpdatedSSEEvent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := newIntegrationService(t)
|
||||
|
||||
user, err := svc.Store.CreateUser(ctx, &store.User{
|
||||
Username: "user", Role: store.RoleAdmin, Email: "user@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
uctx := userCtx(ctx, user.ID)
|
||||
|
||||
memo1, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{
|
||||
Memo: &v1pb.Memo{Content: "memo one", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
memo2, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{
|
||||
Memo: &v1pb.Memo{Content: "memo two", Visibility: v1pb.Visibility_PUBLIC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin)
|
||||
defer svc.SSEHub.Unsubscribe(client)
|
||||
|
||||
_, err = svc.SetMemoRelations(uctx, &v1pb.SetMemoRelationsRequest{
|
||||
Name: memo1.Name,
|
||||
Relations: []*v1pb.MemoRelation{
|
||||
{
|
||||
RelatedMemo: &v1pb.MemoRelation_Memo{Name: memo2.Name},
|
||||
Type: v1pb.MemoRelation_REFERENCE,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
data := mustReceive(t, client.events, time.Second)
|
||||
payload := string(data)
|
||||
assert.Contains(t, payload, `"memo.updated"`)
|
||||
assert.Contains(t, payload, memo1.Name)
|
||||
mustNotReceive(t, client.events, 100*time.Millisecond)
|
||||
}
|
||||
|
|
@ -7,6 +7,9 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
apiv1 "github.com/usememos/memos/server/router/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestCreateAttachment(t *testing.T) {
|
||||
|
|
@ -56,4 +59,56 @@ func TestCreateAttachment(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, "application/octet-stream", attachment.Type)
|
||||
})
|
||||
|
||||
t.Run("LocalStorage_PathCollisionUsesUniqueReference", func(t *testing.T) {
|
||||
_, err := ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_STORAGE,
|
||||
Value: &storepb.InstanceSetting_StorageSetting{
|
||||
StorageSetting: &storepb.InstanceStorageSetting{
|
||||
StorageType: storepb.InstanceStorageSetting_LOCAL,
|
||||
FilepathTemplate: "assets/{filename}",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
first, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
|
||||
Attachment: &v1pb.Attachment{
|
||||
Filename: "screenshot.png",
|
||||
Type: "image/png",
|
||||
Content: []byte("first-image"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
second, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
|
||||
Attachment: &v1pb.Attachment{
|
||||
Filename: "screenshot.png",
|
||||
Type: "image/png",
|
||||
Content: []byte("second-image"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
firstUID, err := apiv1.ExtractAttachmentUIDFromName(first.Name)
|
||||
require.NoError(t, err)
|
||||
secondUID, err := apiv1.ExtractAttachmentUIDFromName(second.Name)
|
||||
require.NoError(t, err)
|
||||
|
||||
firstStoreAttachment, err := ts.Store.GetAttachment(ctx, &store.FindAttachment{UID: &firstUID})
|
||||
require.NoError(t, err)
|
||||
secondStoreAttachment, err := ts.Store.GetAttachment(ctx, &store.FindAttachment{UID: &secondUID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, firstStoreAttachment)
|
||||
require.NotNil(t, secondStoreAttachment)
|
||||
|
||||
require.NotEqual(t, firstStoreAttachment.Reference, secondStoreAttachment.Reference)
|
||||
|
||||
firstBlob, err := ts.Service.GetAttachmentBlob(firstStoreAttachment)
|
||||
require.NoError(t, err)
|
||||
secondBlob, err := ts.Service.GetAttachmentBlob(secondStoreAttachment)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("first-image"), firstBlob)
|
||||
require.Equal(t, []byte("second-image"), secondBlob)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ import (
|
|||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
apiv1 "github.com/usememos/memos/server/router/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestCreateIdentityProvider(t *testing.T) {
|
||||
|
|
@ -233,7 +235,7 @@ func TestGetIdentityProvider(t *testing.T) {
|
|||
Name: created.Name,
|
||||
}
|
||||
|
||||
// Test unauthenticated, should not contain client secret
|
||||
// ClientSecret is write-only: never returned in responses, even to admins.
|
||||
resp, err := ts.Service.GetIdentityProvider(ctx, getReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
|
@ -242,18 +244,16 @@ func TestGetIdentityProvider(t *testing.T) {
|
|||
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
|
||||
require.NotNil(t, resp.Config.GetOauth2Config())
|
||||
require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId)
|
||||
require.Equal(t, "", resp.Config.GetOauth2Config().ClientSecret)
|
||||
require.Empty(t, resp.Config.GetOauth2Config().ClientSecret,
|
||||
"ClientSecret must never be returned in responses")
|
||||
|
||||
// Test as host user, should contain client secret
|
||||
respHostUser, err := ts.Service.GetIdentityProvider(userCtx, getReq)
|
||||
// Same for admin: secret is still write-only.
|
||||
respAdmin, err := ts.Service.GetIdentityProvider(userCtx, getReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respHostUser)
|
||||
require.Equal(t, created.Name, respHostUser.Name)
|
||||
require.Equal(t, "Test Provider", respHostUser.Title)
|
||||
require.Equal(t, v1pb.IdentityProvider_OAUTH2, respHostUser.Type)
|
||||
require.NotNil(t, respHostUser.Config.GetOauth2Config())
|
||||
require.Equal(t, "test-client", respHostUser.Config.GetOauth2Config().ClientId)
|
||||
require.Equal(t, "test-secret", respHostUser.Config.GetOauth2Config().ClientSecret)
|
||||
require.NotNil(t, respAdmin)
|
||||
require.Equal(t, "test-client", respAdmin.Config.GetOauth2Config().ClientId)
|
||||
require.Empty(t, respAdmin.Config.GetOauth2Config().ClientSecret,
|
||||
"ClientSecret must never be returned in responses, even to admins")
|
||||
})
|
||||
|
||||
t.Run("GetIdentityProvider not found", func(t *testing.T) {
|
||||
|
|
@ -361,6 +361,67 @@ func TestUpdateIdentityProvider(t *testing.T) {
|
|||
require.Equal(t, "updated-client", updated.Config.GetOauth2Config().ClientId)
|
||||
})
|
||||
|
||||
t.Run("UpdateIdentityProvider empty ClientSecret preserves existing credential", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create IDP with a real secret.
|
||||
created, err := ts.Service.CreateIdentityProvider(userCtx, &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Title: "Preserve Secret Test",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "cid",
|
||||
ClientSecret: "original-secret",
|
||||
AuthUrl: "https://ex.com/auth",
|
||||
TokenUrl: "https://ex.com/token",
|
||||
UserInfoUrl: "https://ex.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{Identifier: "id"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update with empty ClientSecret (simulating UI that doesn't resend the secret).
|
||||
_, err = ts.Service.UpdateIdentityProvider(userCtx, &v1pb.UpdateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Name: created.Name,
|
||||
Title: "Updated Title",
|
||||
Type: v1pb.IdentityProvider_OAUTH2,
|
||||
Config: &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: "cid",
|
||||
ClientSecret: "", // empty = preserve existing
|
||||
AuthUrl: "https://ex.com/auth",
|
||||
TokenUrl: "https://ex.com/token",
|
||||
UserInfoUrl: "https://ex.com/user",
|
||||
FieldMapping: &v1pb.FieldMapping{Identifier: "id"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"title", "config"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the stored secret was preserved by reading from store directly.
|
||||
uid, _ := apiv1.ExtractIdentityProviderUIDFromName(created.Name)
|
||||
stored, err := ts.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{UID: &uid})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "original-secret", stored.Config.GetOauth2Config().ClientSecret,
|
||||
"existing ClientSecret must be preserved when an empty value is sent")
|
||||
require.Equal(t, "Updated Title", stored.Name)
|
||||
})
|
||||
|
||||
t.Run("UpdateIdentityProvider missing update mask", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
|
|
|||
|
|
@ -204,21 +204,38 @@ func TestGetInstanceSetting(t *testing.T) {
|
|||
require.Empty(t, resp.GetTagsSetting().GetTags())
|
||||
})
|
||||
|
||||
t.Run("GetInstanceSetting - notification setting", func(t *testing.T) {
|
||||
t.Run("GetInstanceSetting - notification setting requires admin", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
req := &v1pb.GetInstanceSettingRequest{
|
||||
Name: "instance/settings/NOTIFICATION",
|
||||
}
|
||||
resp, err := ts.Service.GetInstanceSetting(ctx, req)
|
||||
admin, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
adminCtx := ts.CreateUserContext(ctx, admin.ID)
|
||||
|
||||
regularUser, err := ts.CreateRegularUser(ctx, "user")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
req := &v1pb.GetInstanceSettingRequest{Name: "instance/settings/NOTIFICATION"}
|
||||
|
||||
// Unauthenticated request must be rejected.
|
||||
_, err = ts.Service.GetInstanceSetting(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not authenticated")
|
||||
|
||||
// Non-admin request must be rejected.
|
||||
_, err = ts.Service.GetInstanceSetting(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
// Admin request succeeds and does NOT expose SmtpPassword.
|
||||
resp, err := ts.Service.GetInstanceSetting(adminCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "instance/settings/NOTIFICATION", resp.Name)
|
||||
require.NotNil(t, resp.GetNotificationSetting())
|
||||
require.NotNil(t, resp.GetNotificationSetting().GetEmail())
|
||||
require.False(t, resp.GetNotificationSetting().GetEmail().GetEnabled())
|
||||
require.Empty(t, resp.GetNotificationSetting().GetEmail().GetSmtpPassword(),
|
||||
"SmtpPassword must never be returned in responses")
|
||||
})
|
||||
|
||||
t.Run("GetInstanceSetting - invalid setting name", func(t *testing.T) {
|
||||
|
|
@ -301,7 +318,7 @@ func TestUpdateInstanceSetting(t *testing.T) {
|
|||
require.Contains(t, err.Error(), "invalid instance setting")
|
||||
})
|
||||
|
||||
t.Run("UpdateInstanceSetting - notification setting", func(t *testing.T) {
|
||||
t.Run("UpdateInstanceSetting - tags setting without color", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
|
|
@ -309,6 +326,36 @@ func TestUpdateInstanceSetting(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
resp, err := ts.Service.UpdateInstanceSetting(ts.CreateUserContext(ctx, hostUser.ID), &v1pb.UpdateInstanceSettingRequest{
|
||||
Setting: &v1pb.InstanceSetting{
|
||||
Name: "instance/settings/TAGS",
|
||||
Value: &v1pb.InstanceSetting_TagsSetting_{
|
||||
TagsSetting: &v1pb.InstanceSetting_TagsSetting{
|
||||
Tags: map[string]*v1pb.InstanceSetting_TagMetadata{
|
||||
"spoiler": {
|
||||
BlurContent: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.GetTagsSetting())
|
||||
require.Contains(t, resp.GetTagsSetting().GetTags(), "spoiler")
|
||||
require.Nil(t, resp.GetTagsSetting().GetTags()["spoiler"].GetBackgroundColor())
|
||||
require.True(t, resp.GetTagsSetting().GetTags()["spoiler"].GetBlurContent())
|
||||
})
|
||||
|
||||
t.Run("UpdateInstanceSetting - notification setting password is write-only", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
adminCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Save notification setting with a password.
|
||||
resp, err := ts.Service.UpdateInstanceSetting(adminCtx, &v1pb.UpdateInstanceSettingRequest{
|
||||
Setting: &v1pb.InstanceSetting{
|
||||
Name: "instance/settings/NOTIFICATION",
|
||||
Value: &v1pb.InstanceSetting_NotificationSetting_{
|
||||
|
|
@ -330,9 +377,117 @@ func TestUpdateInstanceSetting(t *testing.T) {
|
|||
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"notification_setting"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.GetNotificationSetting())
|
||||
require.NotNil(t, resp.GetNotificationSetting().GetEmail())
|
||||
require.True(t, resp.GetNotificationSetting().GetEmail().GetEnabled())
|
||||
require.Equal(t, "smtp.example.com", resp.GetNotificationSetting().GetEmail().GetSmtpHost())
|
||||
// Password must not be returned even in the update response.
|
||||
require.Empty(t, resp.GetNotificationSetting().GetEmail().GetSmtpPassword(),
|
||||
"SmtpPassword must never be returned in responses")
|
||||
})
|
||||
|
||||
t.Run("UpdateInstanceSetting - empty password preserves existing credential", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
adminCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
notificationSetting := &v1pb.InstanceSetting{
|
||||
Name: "instance/settings/NOTIFICATION",
|
||||
Value: &v1pb.InstanceSetting_NotificationSetting_{
|
||||
NotificationSetting: &v1pb.InstanceSetting_NotificationSetting{
|
||||
Email: &v1pb.InstanceSetting_NotificationSetting_EmailSetting{
|
||||
Enabled: true,
|
||||
SmtpHost: "smtp.example.com",
|
||||
SmtpPort: 587,
|
||||
SmtpUsername: "bot@example.com",
|
||||
SmtpPassword: "original-password",
|
||||
FromEmail: "bot@example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// First save with a real password.
|
||||
_, err = ts.Service.UpdateInstanceSetting(adminCtx, &v1pb.UpdateInstanceSettingRequest{
|
||||
Setting: notificationSetting,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second update with an empty password (simulating a UI that doesn't re-send the secret).
|
||||
notificationSetting.GetNotificationSetting().GetEmail().SmtpPassword = ""
|
||||
notificationSetting.GetNotificationSetting().GetEmail().SmtpHost = "smtp2.example.com"
|
||||
_, err = ts.Service.UpdateInstanceSetting(adminCtx, &v1pb.UpdateInstanceSettingRequest{
|
||||
Setting: notificationSetting,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// The stored setting should have preserved the original password.
|
||||
stored, err := ts.Store.GetInstanceNotificationSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "original-password", stored.GetEmail().GetSmtpPassword(),
|
||||
"existing SmtpPassword must be preserved when an empty value is sent")
|
||||
require.Equal(t, "smtp2.example.com", stored.GetEmail().GetSmtpHost())
|
||||
})
|
||||
|
||||
t.Run("UpdateInstanceSetting - S3 secret is write-only and preserved on empty", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
adminCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Save storage setting with a real secret.
|
||||
_, err = ts.Service.UpdateInstanceSetting(adminCtx, &v1pb.UpdateInstanceSettingRequest{
|
||||
Setting: &v1pb.InstanceSetting{
|
||||
Name: "instance/settings/STORAGE",
|
||||
Value: &v1pb.InstanceSetting_StorageSetting_{
|
||||
StorageSetting: &v1pb.InstanceSetting_StorageSetting{
|
||||
S3Config: &v1pb.InstanceSetting_StorageSetting_S3Config{
|
||||
AccessKeyId: "AKID",
|
||||
AccessKeySecret: "super-secret",
|
||||
Endpoint: "s3.example.com",
|
||||
Region: "us-east-1",
|
||||
Bucket: "memos",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read back: secret must not be returned.
|
||||
resp, err := ts.Service.GetInstanceSetting(adminCtx, &v1pb.GetInstanceSettingRequest{
|
||||
Name: "instance/settings/STORAGE",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, resp.GetStorageSetting().GetS3Config().GetAccessKeySecret(),
|
||||
"AccessKeySecret must never be returned in responses")
|
||||
|
||||
// Update with empty secret; original must be preserved in the store.
|
||||
_, err = ts.Service.UpdateInstanceSetting(adminCtx, &v1pb.UpdateInstanceSettingRequest{
|
||||
Setting: &v1pb.InstanceSetting{
|
||||
Name: "instance/settings/STORAGE",
|
||||
Value: &v1pb.InstanceSetting_StorageSetting_{
|
||||
StorageSetting: &v1pb.InstanceSetting_StorageSetting{
|
||||
S3Config: &v1pb.InstanceSetting_StorageSetting_S3Config{
|
||||
AccessKeyId: "AKID",
|
||||
AccessKeySecret: "", // omitted / not changed
|
||||
Endpoint: "s3-v2.example.com",
|
||||
Region: "us-east-1",
|
||||
Bucket: "memos",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
stored, err := ts.Store.GetInstanceStorageSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "super-secret", stored.GetS3Config().GetAccessKeySecret(),
|
||||
"existing AccessKeySecret must be preserved when an empty value is sent")
|
||||
require.Equal(t, "s3-v2.example.com", stored.GetS3Config().GetEndpoint())
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ func TestListMemos(t *testing.T) {
|
|||
memoOneRes := memos.Memos[memoOneResIdx]
|
||||
require.NotNil(t, memoOneRes)
|
||||
|
||||
require.Equal(t, fmt.Sprintf("users/%d", userOne.ID), memoOneRes.GetCreator())
|
||||
require.Equal(t, fmt.Sprintf("users/%s", userOne.Username), memoOneRes.GetCreator())
|
||||
require.Equal(t, apiv1.Visibility_PROTECTED, memoOneRes.GetVisibility())
|
||||
require.Equal(t, memoOne.Content, memoOneRes.GetContent())
|
||||
require.Equal(t, memoOne.Content[:64]+"...", memoOneRes.GetSnippet(), "memoOne's content is snipped past the 64 char limit")
|
||||
|
|
@ -202,7 +202,7 @@ func TestListMemos(t *testing.T) {
|
|||
memoTwoRes := memos.Memos[memoTwoResIdx]
|
||||
require.NotNil(t, memoTwoRes)
|
||||
|
||||
require.Equal(t, fmt.Sprintf("users/%d", userTwo.ID), memoTwoRes.GetCreator())
|
||||
require.Equal(t, fmt.Sprintf("users/%s", userTwo.Username), memoTwoRes.GetCreator())
|
||||
require.Equal(t, apiv1.Visibility_PROTECTED, memoTwoRes.GetVisibility())
|
||||
require.Equal(t, memoTwo.Content, memoTwoRes.GetContent())
|
||||
require.Empty(t, memoTwoRes.Attachments)
|
||||
|
|
@ -227,7 +227,7 @@ func TestListMemos(t *testing.T) {
|
|||
memoThreeRes := memos.Memos[memoThreeResIdx]
|
||||
require.NotNil(t, memoThreeRes)
|
||||
|
||||
require.Equal(t, fmt.Sprintf("users/%d", userOne.ID), memoThreeRes.GetCreator())
|
||||
require.Equal(t, fmt.Sprintf("users/%s", userOne.Username), memoThreeRes.GetCreator())
|
||||
require.Equal(t, apiv1.Visibility_PROTECTED, memoThreeRes.GetVisibility())
|
||||
require.Equal(t, memoThree.Content, memoThreeRes.GetContent())
|
||||
require.Empty(t, memoThreeRes.Attachments)
|
||||
|
|
@ -237,7 +237,7 @@ func TestListMemos(t *testing.T) {
|
|||
// verify memoThree's reactions
|
||||
require.Len(t, memoThreeRes.Reactions, 2)
|
||||
// userOne's reaction
|
||||
userOneReactionIdx := slices.IndexFunc(memoThreeRes.Reactions, func(r *apiv1.Reaction) bool { return r.GetCreator() == fmt.Sprintf("users/%d", userOne.ID) })
|
||||
userOneReactionIdx := slices.IndexFunc(memoThreeRes.Reactions, func(r *apiv1.Reaction) bool { return r.GetCreator() == fmt.Sprintf("users/%s", userOne.Username) })
|
||||
require.NotEqual(t, userOneReactionIdx, -1)
|
||||
|
||||
userOneReaction := memoThreeRes.Reactions[userOneReactionIdx]
|
||||
|
|
@ -245,7 +245,7 @@ func TestListMemos(t *testing.T) {
|
|||
require.Equal(t, "❤️", userOneReaction.ReactionType)
|
||||
|
||||
// userTwo's reaction
|
||||
userTwoReactionIdx := slices.IndexFunc(memoThreeRes.Reactions, func(r *apiv1.Reaction) bool { return r.GetCreator() == fmt.Sprintf("users/%d", userTwo.ID) })
|
||||
userTwoReactionIdx := slices.IndexFunc(memoThreeRes.Reactions, func(r *apiv1.Reaction) bool { return r.GetCreator() == fmt.Sprintf("users/%s", userTwo.Username) })
|
||||
require.NotEqual(t, userTwoReactionIdx, -1)
|
||||
|
||||
userTwoReaction := memoThreeRes.Reactions[userTwoReactionIdx]
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@ import (
|
|||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestDeleteMemoShare_VerifiesShareBelongsToMemo(t *testing.T) {
|
||||
|
|
@ -107,3 +109,107 @@ func TestGetMemoByShare_IncludesReactions(t *testing.T) {
|
|||
require.Equal(t, "👍", sharedMemo.Reactions[0].ReactionType)
|
||||
require.Equal(t, memo.Name, sharedMemo.Reactions[0].ContentId)
|
||||
}
|
||||
|
||||
func TestGetMemoByShare_ReturnsNotFoundForUnknownShare(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
_, err := ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{
|
||||
ShareId: "missing-share-token",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.NotFound, status.Code(err))
|
||||
}
|
||||
|
||||
func TestGetMemoByShare_ReturnsNotFoundForExpiredShare(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "share-expired")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "memo with expired share",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
expiredTs := time.Now().Add(-time.Hour).Unix()
|
||||
expiredShare, err := ts.Store.CreateMemoShare(ctx, &store.MemoShare{
|
||||
UID: "expired-share-token",
|
||||
MemoID: parseMemoIDFromNameForTest(t, ts, memo.Name),
|
||||
CreatorID: user.ID,
|
||||
ExpiresTs: &expiredTs,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{
|
||||
ShareId: expiredShare.UID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.NotFound, status.Code(err))
|
||||
}
|
||||
|
||||
func TestGetMemoByShare_ReturnsNotFoundForArchivedMemo(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "share-archived")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
memoResp, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "memo that will be archived",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
share, err := ts.Service.CreateMemoShare(userCtx, &apiv1.CreateMemoShareRequest{
|
||||
Parent: memoResp.Name,
|
||||
MemoShare: &apiv1.MemoShare{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoID := parseMemoIDFromNameForTest(t, ts, memoResp.Name)
|
||||
memo, err := ts.Store.GetMemo(ctx, &store.FindMemo{ID: &memoID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
archived := store.Archived
|
||||
err = ts.Store.UpdateMemo(ctx, &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
RowStatus: &archived,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
shareToken := share.Name[strings.LastIndex(share.Name, "/")+1:]
|
||||
_, err = ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{
|
||||
ShareId: shareToken,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.NotFound, status.Code(err))
|
||||
}
|
||||
|
||||
func parseMemoIDFromNameForTest(t *testing.T, ts *TestService, memoName string) int32 {
|
||||
t.Helper()
|
||||
|
||||
memoUID, ok := strings.CutPrefix(memoName, "memos/")
|
||||
require.True(t, ok, "memo name must start with memos/: %s", memoName)
|
||||
|
||||
memo, err := ts.Store.GetMemo(context.Background(), &store.FindMemo{UID: &memoUID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
return memo.ID
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ func TestDeleteMemoReaction(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reaction)
|
||||
require.Equal(t, "users/user", reaction.Creator)
|
||||
|
||||
// Delete reaction - should succeed
|
||||
_, err = ts.Service.DeleteMemoReaction(userCtx, &apiv1.DeleteMemoReactionRequest{
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ func TestListShortcuts(t *testing.T) {
|
|||
|
||||
// List shortcuts (should be empty initially)
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
}
|
||||
|
||||
resp, err := ts.Service.ListShortcuts(userCtx, req)
|
||||
|
|
@ -50,7 +50,7 @@ func TestListShortcuts(t *testing.T) {
|
|||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user2.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user2.Username),
|
||||
}
|
||||
|
||||
_, err = ts.Service.ListShortcuts(userCtx, req)
|
||||
|
|
@ -82,14 +82,33 @@ func TestListShortcuts(t *testing.T) {
|
|||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
_, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: "users/1",
|
||||
Parent: "users/testuser",
|
||||
}
|
||||
|
||||
_, err := ts.Service.ListShortcuts(ctx, req)
|
||||
_, err = ts.Service.ListShortcuts(ctx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "permission denied")
|
||||
})
|
||||
|
||||
t.Run("ListShortcuts rejects numeric parent", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
_, err = ts.Service.ListShortcuts(userCtx, &v1pb.ListShortcutsRequest{
|
||||
Parent: "users/1",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid user name")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetShortcut(t *testing.T) {
|
||||
|
|
@ -108,7 +127,7 @@ func TestGetShortcut(t *testing.T) {
|
|||
|
||||
// First create a shortcut
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Test Shortcut",
|
||||
Filter: "tag in [\"test\"]",
|
||||
|
|
@ -144,7 +163,7 @@ func TestGetShortcut(t *testing.T) {
|
|||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user1.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "User1 Shortcut",
|
||||
Filter: "tag in [\"user1\"]",
|
||||
|
|
@ -197,7 +216,7 @@ func TestGetShortcut(t *testing.T) {
|
|||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.GetShortcutRequest{
|
||||
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
|
||||
Name: fmt.Sprintf("users/%s", user.Username) + "/shortcuts/nonexistent",
|
||||
}
|
||||
|
||||
_, err = ts.Service.GetShortcut(userCtx, req)
|
||||
|
|
@ -221,7 +240,7 @@ func TestCreateShortcut(t *testing.T) {
|
|||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "My Shortcut",
|
||||
Filter: "tag in [\"important\"]",
|
||||
|
|
@ -233,11 +252,11 @@ func TestCreateShortcut(t *testing.T) {
|
|||
require.NotNil(t, resp)
|
||||
require.Equal(t, "My Shortcut", resp.Title)
|
||||
require.Equal(t, "tag in [\"important\"]", resp.Filter)
|
||||
require.Contains(t, resp.Name, fmt.Sprintf("users/%d/shortcuts/", user.ID))
|
||||
require.Contains(t, resp.Name, fmt.Sprintf("users/%s/shortcuts/", user.Username))
|
||||
|
||||
// Verify the shortcut was created by listing
|
||||
listReq := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
}
|
||||
|
||||
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
|
|
@ -260,7 +279,7 @@ func TestCreateShortcut(t *testing.T) {
|
|||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user2.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user2.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Forbidden Shortcut",
|
||||
Filter: "tag in [\"forbidden\"]",
|
||||
|
|
@ -308,7 +327,7 @@ func TestCreateShortcut(t *testing.T) {
|
|||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Invalid Filter Shortcut",
|
||||
Filter: "invalid||filter))syntax",
|
||||
|
|
@ -332,7 +351,7 @@ func TestCreateShortcut(t *testing.T) {
|
|||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Filter: "tag in [\"test\"]",
|
||||
},
|
||||
|
|
@ -360,7 +379,7 @@ func TestUpdateShortcut(t *testing.T) {
|
|||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Original Title",
|
||||
Filter: "tag in [\"original\"]",
|
||||
|
|
@ -403,7 +422,7 @@ func TestUpdateShortcut(t *testing.T) {
|
|||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user1.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "User1 Shortcut",
|
||||
Filter: "tag in [\"user1\"]",
|
||||
|
|
@ -442,7 +461,7 @@ func TestUpdateShortcut(t *testing.T) {
|
|||
|
||||
req := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: fmt.Sprintf("users/%d/shortcuts/test", user.ID),
|
||||
Name: fmt.Sprintf("users/%s/shortcuts/test", user.Username),
|
||||
Title: "Updated Title",
|
||||
},
|
||||
}
|
||||
|
|
@ -484,7 +503,7 @@ func TestUpdateShortcut(t *testing.T) {
|
|||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Test Shortcut",
|
||||
Filter: "tag in [\"test\"]",
|
||||
|
|
@ -527,7 +546,7 @@ func TestDeleteShortcut(t *testing.T) {
|
|||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Shortcut to Delete",
|
||||
Filter: "tag in [\"delete\"]",
|
||||
|
|
@ -547,7 +566,7 @@ func TestDeleteShortcut(t *testing.T) {
|
|||
|
||||
// Verify deletion by listing shortcuts
|
||||
listReq := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
}
|
||||
|
||||
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
|
|
@ -577,7 +596,7 @@ func TestDeleteShortcut(t *testing.T) {
|
|||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user1.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "User1 Shortcut",
|
||||
Filter: "tag in [\"user1\"]",
|
||||
|
|
@ -623,7 +642,7 @@ func TestDeleteShortcut(t *testing.T) {
|
|||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.DeleteShortcutRequest{
|
||||
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
|
||||
Name: fmt.Sprintf("users/%s", user.Username) + "/shortcuts/nonexistent",
|
||||
}
|
||||
|
||||
_, err = ts.Service.DeleteShortcut(userCtx, req)
|
||||
|
|
@ -660,7 +679,7 @@ func TestShortcutFiltering(t *testing.T) {
|
|||
|
||||
for i, filter := range validFilters {
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Valid Filter " + string(rune(i)),
|
||||
Filter: filter,
|
||||
|
|
@ -697,7 +716,7 @@ func TestShortcutFiltering(t *testing.T) {
|
|||
|
||||
for _, filter := range invalidFilters {
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Invalid Filter Test",
|
||||
Filter: filter,
|
||||
|
|
@ -727,7 +746,7 @@ func TestShortcutCRUDComplete(t *testing.T) {
|
|||
|
||||
// 1. Create multiple shortcuts
|
||||
shortcut1Req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Work Notes",
|
||||
Filter: "tag in [\"work\"]",
|
||||
|
|
@ -735,7 +754,7 @@ func TestShortcutCRUDComplete(t *testing.T) {
|
|||
}
|
||||
|
||||
shortcut2Req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Title: "Personal Notes",
|
||||
Filter: "tag in [\"personal\"]",
|
||||
|
|
@ -752,7 +771,7 @@ func TestShortcutCRUDComplete(t *testing.T) {
|
|||
|
||||
// 2. List shortcuts and verify both exist
|
||||
listReq := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
Parent: fmt.Sprintf("users/%s", user.Username),
|
||||
}
|
||||
|
||||
listResp, err := ts.Service.ListShortcuts(userCtx, listReq)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,113 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestUserEmailVisibility(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetUser redacts email for anonymous callers", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "targetuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := ts.Service.GetUser(ctx, &apiv1.GetUserRequest{
|
||||
Name: "users/targetuser",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, user.Username, got.Username)
|
||||
require.Empty(t, got.Email)
|
||||
})
|
||||
|
||||
t.Run("GetUser redacts email for other regular users", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
targetUser, err := ts.CreateRegularUser(ctx, "targetuser")
|
||||
require.NoError(t, err)
|
||||
viewer, err := ts.CreateRegularUser(ctx, "vieweruser")
|
||||
require.NoError(t, err)
|
||||
|
||||
viewerCtx := ts.CreateUserContext(ctx, viewer.ID)
|
||||
got, err := ts.Service.GetUser(viewerCtx, &apiv1.GetUserRequest{
|
||||
Name: "users/targetuser",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, targetUser.Username, got.Username)
|
||||
require.Empty(t, got.Email)
|
||||
})
|
||||
|
||||
t.Run("GetUser returns email for the same user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "selfuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
got, err := ts.Service.GetUser(userCtx, &apiv1.GetUserRequest{
|
||||
Name: "users/selfuser",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, user.Email, got.Email)
|
||||
})
|
||||
|
||||
t.Run("GetUser returns email for admins", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
targetUser, err := ts.CreateRegularUser(ctx, "targetuser")
|
||||
require.NoError(t, err)
|
||||
admin, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
adminCtx := ts.CreateUserContext(ctx, admin.ID)
|
||||
got, err := ts.Service.GetUser(adminCtx, &apiv1.GetUserRequest{
|
||||
Name: "users/targetuser",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, targetUser.Email, got.Email)
|
||||
})
|
||||
|
||||
t.Run("GetCurrentUser returns email for the authenticated user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "currentuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
got, err := ts.Service.GetCurrentUser(userCtx, &apiv1.GetCurrentUserRequest{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.NotNil(t, got.User)
|
||||
require.Equal(t, user.Email, got.User.Email)
|
||||
})
|
||||
|
||||
t.Run("GetInstanceProfile redacts admin email for anonymous callers", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
admin, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := ts.Service.GetInstanceProfile(ctx, &apiv1.GetInstanceProfileRequest{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.NotNil(t, got.Admin)
|
||||
require.Equal(t, admin.Username, got.Admin.Username)
|
||||
require.Empty(t, got.Admin.Email)
|
||||
})
|
||||
}
|
||||
|
|
@ -43,12 +43,14 @@ func TestListUserNotificationsIncludesMemoCommentPayload(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
resp, err := ts.Service.ListUserNotifications(ownerCtx, &apiv1.ListUserNotificationsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", owner.ID),
|
||||
Parent: fmt.Sprintf("users/%s", owner.Username),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Notifications, 1)
|
||||
|
||||
notification := resp.Notifications[0]
|
||||
require.Contains(t, notification.Name, fmt.Sprintf("users/%s/notifications/", owner.Username))
|
||||
require.Equal(t, fmt.Sprintf("users/%s", commenter.Username), notification.Sender)
|
||||
require.Equal(t, apiv1.UserNotification_MEMO_COMMENT, notification.Type)
|
||||
require.NotNil(t, notification.GetMemoComment())
|
||||
require.Equal(t, comment.Name, notification.GetMemoComment().Memo)
|
||||
|
|
@ -134,10 +136,26 @@ func TestListUserNotificationsOmitsPayloadWhenMemosDeleted(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
resp, err := ts.Service.ListUserNotifications(ownerCtx, &apiv1.ListUserNotificationsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", owner.ID),
|
||||
Parent: fmt.Sprintf("users/%s", owner.Username),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Notifications, 1)
|
||||
require.Equal(t, apiv1.UserNotification_MEMO_COMMENT, resp.Notifications[0].Type)
|
||||
require.Nil(t, resp.Notifications[0].GetMemoComment())
|
||||
}
|
||||
|
||||
func TestListUserNotificationsRejectsNumericParent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
owner, err := ts.CreateRegularUser(ctx, "notification-owner")
|
||||
require.NoError(t, err)
|
||||
ownerCtx := ts.CreateUserContext(ctx, owner.ID)
|
||||
|
||||
_, err = ts.Service.ListUserNotifications(ownerCtx, &apiv1.ListUserNotificationsRequest{
|
||||
Parent: "users/1",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid user name")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func TestUserResourceName(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetUser returns username-based canonical name", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := ts.Service.GetUser(ctx, &apiv1.GetUserRequest{
|
||||
Name: "users/testuser",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, "users/testuser", got.Name)
|
||||
require.Equal(t, user.Username, got.Username)
|
||||
})
|
||||
|
||||
t.Run("CreateUser returns username-based canonical name", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
created, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
|
||||
User: &apiv1.User{
|
||||
Username: "newuser",
|
||||
Email: "newuser@example.com",
|
||||
Password: "password123",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, created)
|
||||
require.Equal(t, "users/newuser", created.Name)
|
||||
})
|
||||
|
||||
t.Run("GetUser rejects numeric user resource names", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
_, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.Service.GetUser(ctx, &apiv1.GetUserRequest{
|
||||
Name: "users/1",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid user name")
|
||||
})
|
||||
}
|
||||
|
|
@ -143,6 +143,7 @@ func TestCreateUserRegistration(t *testing.T) {
|
|||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "users/newadmin", createdUser.Name)
|
||||
require.NotNil(t, createdUser)
|
||||
require.Equal(t, apiv1.User_ADMIN, createdUser.Role)
|
||||
})
|
||||
|
|
@ -168,6 +169,7 @@ func TestCreateUserRegistration(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, createdUser)
|
||||
require.Equal(t, "users/wannabeadmin", createdUser.Name)
|
||||
require.Equal(t, apiv1.User_USER, createdUser.Role, "Unauthenticated users can only create USER role")
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ func TestGetUserStats_TagCount(t *testing.T) {
|
|||
defer ts.Cleanup()
|
||||
|
||||
// Create a test host user
|
||||
user, err := ts.CreateHostUser(ctx, "test_user")
|
||||
user, err := ts.CreateHostUser(ctx, "test-user")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create user context for authentication
|
||||
|
|
@ -40,12 +40,13 @@ func TestGetUserStats_TagCount(t *testing.T) {
|
|||
require.NotNil(t, memo)
|
||||
|
||||
// Test GetUserStats
|
||||
userName := fmt.Sprintf("users/%d", user.ID)
|
||||
userName := fmt.Sprintf("users/%s", user.Username)
|
||||
response, err := ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
|
||||
Name: userName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
require.Equal(t, fmt.Sprintf("users/%s/stats", user.Username), response.Name)
|
||||
|
||||
// Check that the tag count is exactly 1, not 2
|
||||
require.Contains(t, response.TagCount, "test")
|
||||
|
|
@ -102,4 +103,10 @@ func TestGetUserStats_TagCount(t *testing.T) {
|
|||
// The original test tag should still be 2
|
||||
require.Contains(t, response3.TagCount, "test")
|
||||
require.Equal(t, int32(2), response3.TagCount["test"], "Original tag count should remain 2")
|
||||
|
||||
_, err = ts.Service.GetUserStats(userCtx, &v1pb.GetUserStatsRequest{
|
||||
Name: "users/1",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid user name")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/base"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// BuildUserName returns the canonical public resource name for a user.
|
||||
func BuildUserName(username string) string {
|
||||
return UserNamePrefix + username
|
||||
}
|
||||
|
||||
// ExtractUsernameFromName extracts the username token from a user resource name.
|
||||
func ExtractUsernameFromName(name string) (string, error) {
|
||||
tokens, err := GetNameParentTokens(name, UserNamePrefix)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
username := tokens[0]
|
||||
if username == "" {
|
||||
return "", errors.Errorf("invalid user name %q", name)
|
||||
}
|
||||
if _, err := strconv.ParseInt(username, 10, 32); err == nil {
|
||||
return "", errors.Errorf("invalid username %q", username)
|
||||
}
|
||||
if username != strings.ToLower(username) || !base.UIDMatcher.MatchString(username) {
|
||||
return "", errors.Errorf("invalid username %q", username)
|
||||
}
|
||||
return username, nil
|
||||
}
|
||||
|
||||
// ResolveUserByName resolves a username-based user resource name to a store user.
|
||||
func ResolveUserByName(ctx context.Context, stores *store.Store, name string) (*store.User, error) {
|
||||
username, err := ExtractUsernameFromName(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user, err := stores.GetUser(ctx, &store.FindUser{Username: &username})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "resolve user by name: GetUser failed")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
|
@ -64,42 +64,21 @@ func (s *APIV1Service) ListUsers(ctx context.Context, request *v1pb.ListUsersReq
|
|||
TotalSize: int32(len(users)),
|
||||
}
|
||||
for _, user := range users {
|
||||
response.Users = append(response.Users, convertUserFromStore(user))
|
||||
response.Users = append(response.Users, convertUserFromStore(user, currentUser))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetUser(ctx context.Context, request *v1pb.GetUserRequest) (*v1pb.User, error) {
|
||||
// Extract identifier from "users/{id_or_username}"
|
||||
identifier := extractUserIdentifierFromName(request.Name)
|
||||
if identifier == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %s", request.Name)
|
||||
}
|
||||
|
||||
var user *store.User
|
||||
var err error
|
||||
|
||||
// Try to parse as numeric ID first
|
||||
if userID, parseErr := strconv.ParseInt(identifier, 10, 32); parseErr == nil {
|
||||
// It's a numeric ID
|
||||
userID32 := int32(userID)
|
||||
user, err = s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID32,
|
||||
})
|
||||
} else {
|
||||
// It's a username
|
||||
user, err = s.Store.GetUser(ctx, &store.FindUser{
|
||||
Username: &identifier,
|
||||
})
|
||||
}
|
||||
|
||||
user, err := ResolveUserByName(ctx, s.Store, request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %s", request.Name)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
return convertUserFromStore(user), nil
|
||||
currentUser, _ := s.fetchCurrentUser(ctx)
|
||||
return convertUserFromStore(user, currentUser), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserRequest) (*v1pb.User, error) {
|
||||
|
|
@ -176,17 +155,24 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR
|
|||
return nil, status.Errorf(codes.Internal, "failed to create user: %v", err)
|
||||
}
|
||||
|
||||
return convertUserFromStore(user), nil
|
||||
return convertUserFromStore(user, user), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserRequest) (*v1pb.User, error) {
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
|
||||
}
|
||||
userID, err := ExtractUserIDFromName(request.User.Name)
|
||||
user, err := ResolveUserByName(ctx, s.Store, request.User.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
if request.AllowMissing {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
userID := user.ID
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
|
|
@ -200,19 +186,6 @@ func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserR
|
|||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
// Handle allow_missing field
|
||||
if request.AllowMissing {
|
||||
// Could create user if missing, but for now return not found
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
|
||||
currentTs := time.Now().Unix()
|
||||
update := &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
|
|
@ -288,14 +261,18 @@ func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserR
|
|||
return nil, status.Errorf(codes.Internal, "failed to update user: %v", err)
|
||||
}
|
||||
|
||||
return convertUserFromStore(updatedUser), nil
|
||||
return convertUserFromStore(updatedUser, currentUser), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteUser(ctx context.Context, request *v1pb.DeleteUserRequest) (*emptypb.Empty, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Name)
|
||||
user, err := ResolveUserByName(ctx, s.Store, request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
userID := user.ID
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
|
|
@ -307,14 +284,6 @@ func (s *APIV1Service) DeleteUser(ctx context.Context, request *v1pb.DeleteUserR
|
|||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if err := s.Store.DeleteUser(ctx, &store.DeleteUser{
|
||||
ID: user.ID,
|
||||
}); err != nil {
|
||||
|
|
@ -332,12 +301,69 @@ func getDefaultUserGeneralSetting() *v1pb.UserSetting_GeneralSetting {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *APIV1Service) resolveUserFromName(ctx context.Context, name string) (*store.User, error) {
|
||||
user, err := ResolveUserByName(ctx, s.Store, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.Errorf("user not found: %s", name)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) resolveUserAndSettingKeyFromName(ctx context.Context, name string) (*store.User, string, error) {
|
||||
parts := strings.Split(name, "/")
|
||||
if len(parts) != 4 || parts[0] != "users" || parts[2] != "settings" {
|
||||
return nil, "", errors.Errorf("invalid resource name format: %s", name)
|
||||
}
|
||||
|
||||
user, err := s.resolveUserFromName(ctx, BuildUserName(parts[1]))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return user, parts[3], nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) resolveUserAndWebhookIDFromName(ctx context.Context, name string) (*store.User, string, error) {
|
||||
parts := strings.Split(name, "/")
|
||||
if len(parts) != 4 || parts[0] != "users" || parts[2] != "webhooks" {
|
||||
return nil, "", errors.New("invalid webhook name format")
|
||||
}
|
||||
|
||||
user, err := s.resolveUserFromName(ctx, BuildUserName(parts[1]))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return user, parts[3], nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) resolveUserAndNotificationIDFromName(ctx context.Context, name string) (*store.User, int32, error) {
|
||||
parts := strings.Split(name, "/")
|
||||
if len(parts) != 4 || parts[0] != "users" || parts[2] != "notifications" {
|
||||
return nil, 0, errors.Errorf("invalid notification name: %s", name)
|
||||
}
|
||||
|
||||
user, err := s.resolveUserFromName(ctx, BuildUserName(parts[1]))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
id, err := strconv.Atoi(parts[3])
|
||||
if err != nil {
|
||||
return nil, 0, errors.Errorf("invalid notification id: %s", parts[3])
|
||||
}
|
||||
|
||||
return user, int32(id), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetUserSetting(ctx context.Context, request *v1pb.GetUserSettingRequest) (*v1pb.UserSetting, error) {
|
||||
// Parse resource name: users/{user}/settings/{setting}
|
||||
userID, settingKey, err := ExtractUserIDAndSettingKeyFromName(request.Name)
|
||||
user, settingKey, err := s.resolveUserAndSettingKeyFromName(ctx, request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid resource name: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -366,15 +392,16 @@ func (s *APIV1Service) GetUserSetting(ctx context.Context, request *v1pb.GetUser
|
|||
return nil, status.Errorf(codes.Internal, "failed to get user setting: %v", err)
|
||||
}
|
||||
|
||||
return convertUserSettingFromStore(userSetting, userID, storeKey), nil
|
||||
return convertUserSettingFromStore(userSetting, user, storeKey), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateUserSetting(ctx context.Context, request *v1pb.UpdateUserSettingRequest) (*v1pb.UserSetting, error) {
|
||||
// Parse resource name: users/{user}/settings/{setting}
|
||||
userID, settingKey, err := ExtractUserIDAndSettingKeyFromName(request.Setting.Name)
|
||||
user, settingKey, err := s.resolveUserAndSettingKeyFromName(ctx, request.Setting.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid resource name: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -461,10 +488,11 @@ func (s *APIV1Service) UpdateUserSetting(ctx context.Context, request *v1pb.Upda
|
|||
}
|
||||
|
||||
func (s *APIV1Service) ListUserSettings(ctx context.Context, request *v1pb.ListUserSettingsRequest) (*v1pb.ListUserSettingsResponse, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
user, err := s.resolveUserFromName(ctx, request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid parent name: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -488,7 +516,7 @@ func (s *APIV1Service) ListUserSettings(ctx context.Context, request *v1pb.ListU
|
|||
|
||||
settings := make([]*v1pb.UserSetting, 0, len(userSettings))
|
||||
for _, storeSetting := range userSettings {
|
||||
apiSetting := convertUserSettingFromStore(storeSetting, userID, storeSetting.Key)
|
||||
apiSetting := convertUserSettingFromStore(storeSetting, user, storeSetting.Key)
|
||||
if apiSetting != nil {
|
||||
settings = append(settings, apiSetting)
|
||||
}
|
||||
|
|
@ -502,7 +530,7 @@ func (s *APIV1Service) ListUserSettings(ctx context.Context, request *v1pb.ListU
|
|||
}
|
||||
if !hasGeneral {
|
||||
defaultGeneral := &v1pb.UserSetting{
|
||||
Name: fmt.Sprintf("users/%d/settings/%s", userID, convertSettingKeyFromStore(storepb.UserSetting_GENERAL)),
|
||||
Name: fmt.Sprintf("%s/settings/%s", BuildUserName(user.Username), convertSettingKeyFromStore(storepb.UserSetting_GENERAL)),
|
||||
Value: &v1pb.UserSetting_GeneralSetting_{
|
||||
GeneralSetting: getDefaultUserGeneralSetting(),
|
||||
},
|
||||
|
|
@ -533,10 +561,11 @@ func (s *APIV1Service) ListUserSettings(ctx context.Context, request *v1pb.ListU
|
|||
// Authentication: Required (session cookie or access token)
|
||||
// Authorization: User can only list their own tokens.
|
||||
func (s *APIV1Service) ListPersonalAccessTokens(ctx context.Context, request *v1pb.ListPersonalAccessTokensRequest) (*v1pb.ListPersonalAccessTokensResponse, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
user, err := s.resolveUserFromName(ctx, request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
// Verify permission
|
||||
claims := auth.GetUserClaims(ctx)
|
||||
|
|
@ -555,7 +584,7 @@ func (s *APIV1Service) ListPersonalAccessTokens(ctx context.Context, request *v1
|
|||
personalAccessTokens := make([]*v1pb.PersonalAccessToken, len(tokens))
|
||||
for i, token := range tokens {
|
||||
personalAccessTokens[i] = &v1pb.PersonalAccessToken{
|
||||
Name: fmt.Sprintf("%s/personalAccessTokens/%s", request.Parent, token.TokenId),
|
||||
Name: fmt.Sprintf("%s/personalAccessTokens/%s", BuildUserName(user.Username), token.TokenId),
|
||||
Description: token.Description,
|
||||
ExpiresAt: token.ExpiresAt,
|
||||
CreatedAt: token.CreatedAt,
|
||||
|
|
@ -587,10 +616,11 @@ func (s *APIV1Service) ListPersonalAccessTokens(ctx context.Context, request *v1
|
|||
// Authentication: Required (session cookie or access token)
|
||||
// Authorization: User can only create tokens for themselves.
|
||||
func (s *APIV1Service) CreatePersonalAccessToken(ctx context.Context, request *v1pb.CreatePersonalAccessTokenRequest) (*v1pb.CreatePersonalAccessTokenResponse, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
user, err := s.resolveUserFromName(ctx, request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
// Verify permission
|
||||
claims := auth.GetUserClaims(ctx)
|
||||
|
|
@ -625,7 +655,7 @@ func (s *APIV1Service) CreatePersonalAccessToken(ctx context.Context, request *v
|
|||
|
||||
return &v1pb.CreatePersonalAccessTokenResponse{
|
||||
PersonalAccessToken: &v1pb.PersonalAccessToken{
|
||||
Name: fmt.Sprintf("%s/personalAccessTokens/%s", request.Parent, tokenID),
|
||||
Name: fmt.Sprintf("%s/personalAccessTokens/%s", BuildUserName(user.Username), tokenID),
|
||||
Description: request.Description,
|
||||
ExpiresAt: expiresAt,
|
||||
CreatedAt: patRecord.CreatedAt,
|
||||
|
|
@ -648,16 +678,16 @@ func (s *APIV1Service) CreatePersonalAccessToken(ctx context.Context, request *v
|
|||
// Authentication: Required (session cookie or access token)
|
||||
// Authorization: User can only delete their own tokens.
|
||||
func (s *APIV1Service) DeletePersonalAccessToken(ctx context.Context, request *v1pb.DeletePersonalAccessTokenRequest) (*emptypb.Empty, error) {
|
||||
// Parse name: users/{user_id}/personalAccessTokens/{token_id}
|
||||
parts := strings.Split(request.Name, "/")
|
||||
if len(parts) != 4 || parts[0] != "users" || parts[2] != "personalAccessTokens" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid personal access token name")
|
||||
}
|
||||
|
||||
userID, err := util.ConvertStringToInt32(parts[1])
|
||||
user, err := s.resolveUserFromName(ctx, BuildUserName(parts[1]))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user ID: %v", err)
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
tokenID := parts[3]
|
||||
|
||||
// Verify permission
|
||||
|
|
@ -677,10 +707,11 @@ func (s *APIV1Service) DeletePersonalAccessToken(ctx context.Context, request *v
|
|||
}
|
||||
|
||||
func (s *APIV1Service) ListUserWebhooks(ctx context.Context, request *v1pb.ListUserWebhooksRequest) (*v1pb.ListUserWebhooksResponse, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
user, err := s.resolveUserFromName(ctx, request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid parent: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -700,7 +731,7 @@ func (s *APIV1Service) ListUserWebhooks(ctx context.Context, request *v1pb.ListU
|
|||
|
||||
userWebhooks := make([]*v1pb.UserWebhook, 0, len(webhooks))
|
||||
for _, webhook := range webhooks {
|
||||
userWebhooks = append(userWebhooks, convertUserWebhookFromUserSetting(webhook, userID))
|
||||
userWebhooks = append(userWebhooks, convertUserWebhookFromUserSetting(webhook, user))
|
||||
}
|
||||
|
||||
return &v1pb.ListUserWebhooksResponse{
|
||||
|
|
@ -709,10 +740,11 @@ func (s *APIV1Service) ListUserWebhooks(ctx context.Context, request *v1pb.ListU
|
|||
}
|
||||
|
||||
func (s *APIV1Service) CreateUserWebhook(ctx context.Context, request *v1pb.CreateUserWebhookRequest) (*v1pb.UserWebhook, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
user, err := s.resolveUserFromName(ctx, request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid parent: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -744,7 +776,7 @@ func (s *APIV1Service) CreateUserWebhook(ctx context.Context, request *v1pb.Crea
|
|||
return nil, status.Errorf(codes.Internal, "failed to create webhook: %v", err)
|
||||
}
|
||||
|
||||
return convertUserWebhookFromUserSetting(webhook, userID), nil
|
||||
return convertUserWebhookFromUserSetting(webhook, user), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateUserWebhook(ctx context.Context, request *v1pb.UpdateUserWebhookRequest) (*v1pb.UserWebhook, error) {
|
||||
|
|
@ -752,10 +784,11 @@ func (s *APIV1Service) UpdateUserWebhook(ctx context.Context, request *v1pb.Upda
|
|||
return nil, status.Errorf(codes.InvalidArgument, "webhook is required")
|
||||
}
|
||||
|
||||
webhookID, userID, err := parseUserWebhookName(request.Webhook.Name)
|
||||
user, webhookID, err := s.resolveUserAndWebhookIDFromName(ctx, request.Webhook.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -828,14 +861,15 @@ func (s *APIV1Service) UpdateUserWebhook(ctx context.Context, request *v1pb.Upda
|
|||
return nil, status.Errorf(codes.Internal, "failed to update webhook: %v", err)
|
||||
}
|
||||
|
||||
return convertUserWebhookFromUserSetting(updatedWebhook, userID), nil
|
||||
return convertUserWebhookFromUserSetting(updatedWebhook, user), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteUserWebhook(ctx context.Context, request *v1pb.DeleteUserWebhookRequest) (*emptypb.Empty, error) {
|
||||
webhookID, userID, err := parseUserWebhookName(request.Name)
|
||||
user, webhookID, err := s.resolveUserAndWebhookIDFromName(ctx, request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid webhook name: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -884,26 +918,10 @@ func generateUserWebhookID() string {
|
|||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// parseUserWebhookName parses a webhook name and returns the webhook ID and user ID.
|
||||
// Format: users/{user}/webhooks/{webhook}.
|
||||
func parseUserWebhookName(name string) (string, int32, error) {
|
||||
parts := strings.Split(name, "/")
|
||||
if len(parts) != 4 || parts[0] != "users" || parts[2] != "webhooks" {
|
||||
return "", 0, errors.New("invalid webhook name format")
|
||||
}
|
||||
|
||||
userID, err := strconv.ParseInt(parts[1], 10, 32)
|
||||
if err != nil {
|
||||
return "", 0, errors.New("invalid user ID in webhook name")
|
||||
}
|
||||
|
||||
return parts[3], int32(userID), nil
|
||||
}
|
||||
|
||||
// convertUserWebhookFromUserSetting converts a storepb webhook to a v1pb UserWebhook.
|
||||
func convertUserWebhookFromUserSetting(webhook *storepb.WebhooksUserSetting_Webhook, userID int32) *v1pb.UserWebhook {
|
||||
func convertUserWebhookFromUserSetting(webhook *storepb.WebhooksUserSetting_Webhook, user *store.User) *v1pb.UserWebhook {
|
||||
return &v1pb.UserWebhook{
|
||||
Name: fmt.Sprintf("users/%d/webhooks/%s", userID, webhook.Id),
|
||||
Name: fmt.Sprintf("%s/webhooks/%s", BuildUserName(user.Username), webhook.Id),
|
||||
Url: webhook.Url,
|
||||
DisplayName: webhook.Title,
|
||||
// Note: create_time and update_time are not available in the user setting webhook structure
|
||||
|
|
@ -911,19 +929,21 @@ func convertUserWebhookFromUserSetting(webhook *storepb.WebhooksUserSetting_Webh
|
|||
}
|
||||
}
|
||||
|
||||
func convertUserFromStore(user *store.User) *v1pb.User {
|
||||
func convertUserFromStore(user *store.User, viewer *store.User) *v1pb.User {
|
||||
userpb := &v1pb.User{
|
||||
Name: fmt.Sprintf("%s%d", UserNamePrefix, user.ID),
|
||||
Name: BuildUserName(user.Username),
|
||||
State: convertStateFromStore(user.RowStatus),
|
||||
CreateTime: timestamppb.New(time.Unix(user.CreatedTs, 0)),
|
||||
UpdateTime: timestamppb.New(time.Unix(user.UpdatedTs, 0)),
|
||||
Role: convertUserRoleFromStore(user.Role),
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
DisplayName: user.Nickname,
|
||||
AvatarUrl: user.AvatarURL,
|
||||
Description: user.Description,
|
||||
}
|
||||
if canViewerAccessUserEmail(viewer, user) {
|
||||
userpb.Email = user.Email
|
||||
}
|
||||
// Use the avatar URL instead of raw base64 image data to reduce the response size.
|
||||
if user.AvatarURL != "" {
|
||||
// Check if avatar url is base64 format.
|
||||
|
|
@ -937,6 +957,13 @@ func convertUserFromStore(user *store.User) *v1pb.User {
|
|||
return userpb
|
||||
}
|
||||
|
||||
func canViewerAccessUserEmail(viewer, user *store.User) bool {
|
||||
if viewer == nil || user == nil {
|
||||
return false
|
||||
}
|
||||
return viewer.Role == store.RoleAdmin || viewer.ID == user.ID
|
||||
}
|
||||
|
||||
func convertUserRoleFromStore(role store.Role) v1pb.User_Role {
|
||||
switch role {
|
||||
case store.RoleAdmin:
|
||||
|
|
@ -970,26 +997,6 @@ func extractImageInfo(dataURI string) (string, string, error) {
|
|||
return imageType, base64Data, nil
|
||||
}
|
||||
|
||||
// Helper functions for user settings
|
||||
|
||||
// ExtractUserIDAndSettingKeyFromName extracts user ID and setting key from resource name.
|
||||
// e.g., "users/123/settings/general" -> 123, "general".
|
||||
func ExtractUserIDAndSettingKeyFromName(name string) (int32, string, error) {
|
||||
// Expected format: users/{user}/settings/{setting}
|
||||
parts := strings.Split(name, "/")
|
||||
if len(parts) != 4 || parts[0] != "users" || parts[2] != "settings" {
|
||||
return 0, "", errors.Errorf("invalid resource name format: %s", name)
|
||||
}
|
||||
|
||||
userID, err := util.ConvertStringToInt32(parts[1])
|
||||
if err != nil {
|
||||
return 0, "", errors.Errorf("invalid user ID: %s", parts[1])
|
||||
}
|
||||
|
||||
settingKey := parts[3]
|
||||
return userID, settingKey, nil
|
||||
}
|
||||
|
||||
// convertSettingKeyToStore converts API setting key to store enum.
|
||||
func convertSettingKeyToStore(key string) (storepb.UserSetting_Key, error) {
|
||||
switch key {
|
||||
|
|
@ -1017,12 +1024,12 @@ func convertSettingKeyFromStore(key storepb.UserSetting_Key) string {
|
|||
}
|
||||
|
||||
// convertUserSettingFromStore converts store UserSetting to API UserSetting.
|
||||
func convertUserSettingFromStore(storeSetting *storepb.UserSetting, userID int32, key storepb.UserSetting_Key) *v1pb.UserSetting {
|
||||
func convertUserSettingFromStore(storeSetting *storepb.UserSetting, user *store.User, key storepb.UserSetting_Key) *v1pb.UserSetting {
|
||||
if storeSetting == nil {
|
||||
// Return default setting if none exists
|
||||
settingKey := convertSettingKeyFromStore(key)
|
||||
setting := &v1pb.UserSetting{
|
||||
Name: fmt.Sprintf("users/%d/settings/%s", userID, settingKey),
|
||||
Name: fmt.Sprintf("%s/settings/%s", BuildUserName(user.Username), settingKey),
|
||||
}
|
||||
|
||||
switch key {
|
||||
|
|
@ -1043,7 +1050,7 @@ func convertUserSettingFromStore(storeSetting *storepb.UserSetting, userID int32
|
|||
|
||||
settingKey := convertSettingKeyFromStore(storeSetting.Key)
|
||||
setting := &v1pb.UserSetting{
|
||||
Name: fmt.Sprintf("users/%d/settings/%s", userID, settingKey),
|
||||
Name: fmt.Sprintf("%s/settings/%s", BuildUserName(user.Username), settingKey),
|
||||
}
|
||||
|
||||
switch storeSetting.Key {
|
||||
|
|
@ -1063,14 +1070,17 @@ func convertUserSettingFromStore(storeSetting *storepb.UserSetting, userID int32
|
|||
}
|
||||
case storepb.UserSetting_WEBHOOKS:
|
||||
webhooks := storeSetting.GetWebhooks()
|
||||
apiWebhooks := make([]*v1pb.UserWebhook, 0, len(webhooks.Webhooks))
|
||||
for _, webhook := range webhooks.Webhooks {
|
||||
apiWebhook := &v1pb.UserWebhook{
|
||||
Name: fmt.Sprintf("users/%d/webhooks/%s", userID, webhook.Id),
|
||||
Url: webhook.Url,
|
||||
DisplayName: webhook.Title,
|
||||
apiWebhooks := make([]*v1pb.UserWebhook, 0)
|
||||
if webhooks != nil {
|
||||
apiWebhooks = make([]*v1pb.UserWebhook, 0, len(webhooks.Webhooks))
|
||||
for _, webhook := range webhooks.Webhooks {
|
||||
apiWebhook := &v1pb.UserWebhook{
|
||||
Name: fmt.Sprintf("%s/webhooks/%s", BuildUserName(user.Username), webhook.Id),
|
||||
Url: webhook.Url,
|
||||
DisplayName: webhook.Title,
|
||||
}
|
||||
apiWebhooks = append(apiWebhooks, apiWebhook)
|
||||
}
|
||||
apiWebhooks = append(apiWebhooks, apiWebhook)
|
||||
}
|
||||
setting.Value = &v1pb.UserSetting_WebhooksSetting_{
|
||||
WebhooksSetting: &v1pb.UserSetting_WebhooksSetting{
|
||||
|
|
@ -1240,10 +1250,11 @@ func extractUsernameFromComparison(left, right ast.Expr) (string, bool) {
|
|||
// Notifications are backed by the inbox storage layer and represent activities
|
||||
// that require user attention (e.g., memo comments).
|
||||
func (s *APIV1Service) ListUserNotifications(ctx context.Context, request *v1pb.ListUserNotificationsRequest) (*v1pb.ListUserNotificationsResponse, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
user, err := s.resolveUserFromName(ctx, request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
// Verify the requesting user has permission to view these notifications
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
|
|
@ -1268,10 +1279,19 @@ func (s *APIV1Service) ListUserNotifications(ctx context.Context, request *v1pb.
|
|||
return nil, status.Errorf(codes.Internal, "failed to list inboxes: %v", err)
|
||||
}
|
||||
|
||||
// Convert storage layer inboxes to API notifications
|
||||
// Convert storage layer inboxes to API notifications.
|
||||
userIDs := make([]int32, 0, len(inboxes)*2)
|
||||
for _, inbox := range inboxes {
|
||||
userIDs = append(userIDs, inbox.ReceiverID, inbox.SenderID)
|
||||
}
|
||||
usersByID, err := s.listUsersByID(ctx, userIDs)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list notification users: %v", err)
|
||||
}
|
||||
|
||||
notifications := []*v1pb.UserNotification{}
|
||||
for _, inbox := range inboxes {
|
||||
notification, err := s.convertInboxToUserNotification(ctx, inbox)
|
||||
notification, err := s.convertInboxToUserNotificationWithUsers(ctx, inbox, usersByID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert inbox: %v", err)
|
||||
}
|
||||
|
|
@ -1290,7 +1310,7 @@ func (s *APIV1Service) UpdateUserNotification(ctx context.Context, request *v1pb
|
|||
return nil, status.Errorf(codes.InvalidArgument, "notification is required")
|
||||
}
|
||||
|
||||
notificationID, err := ExtractNotificationIDFromName(request.Notification.Name)
|
||||
user, notificationID, err := s.resolveUserAndNotificationIDFromName(ctx, request.Notification.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid notification name: %v", err)
|
||||
}
|
||||
|
|
@ -1303,6 +1323,9 @@ func (s *APIV1Service) UpdateUserNotification(ctx context.Context, request *v1pb
|
|||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.ID != user.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
// Verify ownership before updating
|
||||
inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{
|
||||
ID: ¬ificationID,
|
||||
|
|
@ -1358,7 +1381,7 @@ func (s *APIV1Service) UpdateUserNotification(ctx context.Context, request *v1pb
|
|||
// DeleteUserNotification permanently deletes a notification.
|
||||
// Only the notification owner can delete their notifications.
|
||||
func (s *APIV1Service) DeleteUserNotification(ctx context.Context, request *v1pb.DeleteUserNotificationRequest) (*emptypb.Empty, error) {
|
||||
notificationID, err := ExtractNotificationIDFromName(request.Name)
|
||||
user, notificationID, err := s.resolveUserAndNotificationIDFromName(ctx, request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid notification name: %v", err)
|
||||
}
|
||||
|
|
@ -1371,6 +1394,9 @@ func (s *APIV1Service) DeleteUserNotification(ctx context.Context, request *v1pb
|
|||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
if currentUser.ID != user.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
// Verify ownership before deletion
|
||||
inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{
|
||||
ID: ¬ificationID,
|
||||
|
|
@ -1398,9 +1424,26 @@ func (s *APIV1Service) DeleteUserNotification(ctx context.Context, request *v1pb
|
|||
// convertInboxToUserNotification converts a storage-layer inbox to an API notification.
|
||||
// This handles the mapping between the internal inbox representation and the public API.
|
||||
func (s *APIV1Service) convertInboxToUserNotification(ctx context.Context, inbox *store.Inbox) (*v1pb.UserNotification, error) {
|
||||
usersByID, err := s.listUsersByID(ctx, []int32{inbox.ReceiverID, inbox.SenderID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list notification users: %v", err)
|
||||
}
|
||||
return s.convertInboxToUserNotificationWithUsers(ctx, inbox, usersByID)
|
||||
}
|
||||
|
||||
func (s *APIV1Service) convertInboxToUserNotificationWithUsers(ctx context.Context, inbox *store.Inbox, usersByID map[int32]*store.User) (*v1pb.UserNotification, error) {
|
||||
receiver := usersByID[inbox.ReceiverID]
|
||||
if receiver == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "notification receiver not found")
|
||||
}
|
||||
sender := usersByID[inbox.SenderID]
|
||||
if sender == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "notification sender not found")
|
||||
}
|
||||
|
||||
notification := &v1pb.UserNotification{
|
||||
Name: fmt.Sprintf("users/%d/notifications/%d", inbox.ReceiverID, inbox.ID),
|
||||
Sender: fmt.Sprintf("%s%d", UserNamePrefix, inbox.SenderID),
|
||||
Name: fmt.Sprintf("%s/notifications/%d", BuildUserName(receiver.Username), inbox.ID),
|
||||
Sender: BuildUserName(sender.Username),
|
||||
CreateTime: timestamppb.New(time.Unix(inbox.CreatedTs, 0)),
|
||||
}
|
||||
|
||||
|
|
@ -1470,20 +1513,3 @@ func (s *APIV1Service) convertUserNotificationPayload(ctx context.Context, messa
|
|||
RelatedMemo: fmt.Sprintf("%s%s", MemoNamePrefix, relatedMemo.UID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExtractNotificationIDFromName extracts the notification ID from a resource name.
|
||||
// Expected format: users/{user_id}/notifications/{notification_id}.
|
||||
func ExtractNotificationIDFromName(name string) (int32, error) {
|
||||
pattern := regexp.MustCompile(`^users/(\d+)/notifications/(\d+)$`)
|
||||
matches := pattern.FindStringSubmatch(name)
|
||||
if len(matches) != 3 {
|
||||
return 0, errors.Errorf("invalid notification name: %s", name)
|
||||
}
|
||||
|
||||
id, err := strconv.Atoi(matches[2])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid notification id: %s", matches[2])
|
||||
}
|
||||
|
||||
return int32(id), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,46 @@ import (
|
|||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) listUsersByID(ctx context.Context, userIDs []int32) (map[int32]*store.User, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return map[int32]*store.User{}, nil
|
||||
}
|
||||
|
||||
uniqueUserIDs := make([]int32, 0, len(userIDs))
|
||||
seenUserIDs := make(map[int32]struct{}, len(userIDs))
|
||||
for _, userID := range userIDs {
|
||||
if _, seen := seenUserIDs[userID]; seen {
|
||||
continue
|
||||
}
|
||||
seenUserIDs[userID] = struct{}{}
|
||||
uniqueUserIDs = append(uniqueUserIDs, userID)
|
||||
}
|
||||
|
||||
users, err := s.Store.ListUsers(ctx, &store.FindUser{IDList: uniqueUserIDs})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
usersByID := make(map[int32]*store.User, len(users))
|
||||
for _, user := range users {
|
||||
usersByID[user.ID] = user
|
||||
}
|
||||
return usersByID, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) listUsernamesByID(ctx context.Context, userIDs []int32) (map[int32]string, error) {
|
||||
usersByID, err := s.listUsersByID(ctx, userIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
usernamesByID := make(map[int32]string, len(usersByID))
|
||||
for _, user := range usersByID {
|
||||
usernamesByID[user.ID] = user.Username
|
||||
}
|
||||
return usernamesByID, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListAllUserStats(ctx context.Context, _ *v1pb.ListAllUserStatsRequest) (*v1pb.ListAllUserStatsResponse, error) {
|
||||
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -44,6 +84,7 @@ func (s *APIV1Service) ListAllUserStats(ctx context.Context, _ *v1pb.ListAllUser
|
|||
}
|
||||
|
||||
userMemoStatMap := make(map[int32]*v1pb.UserStats)
|
||||
pinnedMemoIDsByUserID := make(map[int32][]int32)
|
||||
limit := 1000
|
||||
offset := 0
|
||||
memoFind.Limit = &limit
|
||||
|
|
@ -62,7 +103,7 @@ func (s *APIV1Service) ListAllUserStats(ctx context.Context, _ *v1pb.ListAllUser
|
|||
// Initialize user stats if not exists
|
||||
if _, exists := userMemoStatMap[memo.CreatorID]; !exists {
|
||||
userMemoStatMap[memo.CreatorID] = &v1pb.UserStats{
|
||||
Name: fmt.Sprintf("users/%d/stats", memo.CreatorID),
|
||||
Name: "",
|
||||
TagCount: make(map[string]int32),
|
||||
MemoDisplayTimestamps: []*timestamppb.Timestamp{},
|
||||
PinnedMemos: []string{},
|
||||
|
|
@ -110,7 +151,7 @@ func (s *APIV1Service) ListAllUserStats(ctx context.Context, _ *v1pb.ListAllUser
|
|||
|
||||
// Track pinned memos
|
||||
if memo.Pinned {
|
||||
stats.PinnedMemos = append(stats.PinnedMemos, fmt.Sprintf("users/%d/memos/%d", memo.CreatorID, memo.ID))
|
||||
pinnedMemoIDsByUserID[memo.CreatorID] = append(pinnedMemoIDsByUserID[memo.CreatorID], memo.ID)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -118,7 +159,23 @@ func (s *APIV1Service) ListAllUserStats(ctx context.Context, _ *v1pb.ListAllUser
|
|||
}
|
||||
|
||||
userMemoStats := []*v1pb.UserStats{}
|
||||
for _, userMemoStat := range userMemoStatMap {
|
||||
userIDs := make([]int32, 0, len(userMemoStatMap))
|
||||
for userID := range userMemoStatMap {
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
usernamesByID, err := s.listUsernamesByID(ctx, userIDs)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list users: %v", err)
|
||||
}
|
||||
for userID, userMemoStat := range userMemoStatMap {
|
||||
username, ok := usernamesByID[userID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.Internal, "failed to resolve user stats name")
|
||||
}
|
||||
userMemoStat.Name = fmt.Sprintf("%s/stats", BuildUserName(username))
|
||||
for _, memoID := range pinnedMemoIDsByUserID[userID] {
|
||||
userMemoStat.PinnedMemos = append(userMemoStat.PinnedMemos, fmt.Sprintf("%s/memos/%d", BuildUserName(username), memoID))
|
||||
}
|
||||
userMemoStats = append(userMemoStats, userMemoStat)
|
||||
}
|
||||
|
||||
|
|
@ -129,10 +186,14 @@ func (s *APIV1Service) ListAllUserStats(ctx context.Context, _ *v1pb.ListAllUser
|
|||
}
|
||||
|
||||
func (s *APIV1Service) GetUserStats(ctx context.Context, request *v1pb.GetUserStatsRequest) (*v1pb.UserStats, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Name)
|
||||
user, err := ResolveUserByName(ctx, s.Store, request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
userID := user.ID
|
||||
|
||||
currentUser, err := s.fetchCurrentUser(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -211,7 +272,7 @@ func (s *APIV1Service) GetUserStats(ctx context.Context, request *v1pb.GetUserSt
|
|||
}
|
||||
}
|
||||
if memo.Pinned {
|
||||
pinnedMemos = append(pinnedMemos, fmt.Sprintf("users/%d/memos/%d", userID, memo.ID))
|
||||
pinnedMemos = append(pinnedMemos, fmt.Sprintf("%s/memos/%d", BuildUserName(user.Username), memo.ID))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -219,7 +280,7 @@ func (s *APIV1Service) GetUserStats(ctx context.Context, request *v1pb.GetUserSt
|
|||
}
|
||||
|
||||
userStats := &v1pb.UserStats{
|
||||
Name: fmt.Sprintf("users/%d/stats", userID),
|
||||
Name: fmt.Sprintf("%s/stats", BuildUserName(user.Username)),
|
||||
MemoDisplayTimestamps: displayTimestamps,
|
||||
TagCount: tagCount,
|
||||
PinnedMemos: pinnedMemos,
|
||||
|
|
|
|||
|
|
@ -114,9 +114,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
|
|||
AllowOrigins: []string{"*"},
|
||||
}))
|
||||
// Register SSE endpoint with same CORS as rest of /api/v1.
|
||||
gwGroup.GET("/api/v1/sse", func(c *echo.Context) error {
|
||||
return handleSSE(c, s.SSEHub, auth.NewAuthenticator(s.Store, s.Secret))
|
||||
})
|
||||
RegisterSSERoutes(gwGroup, s.SSEHub, s.Store, s.Secret)
|
||||
handler := echo.WrapHandler(gwMux)
|
||||
|
||||
gwGroup.Any("/api/v1/*", handler)
|
||||
|
|
|
|||
|
|
@ -286,9 +286,6 @@ See SAFARI_FIX.md for recommended test coverage.
|
|||
# Test attachment
|
||||
curl "http://localhost:8081/file/attachments/{uid}/file.jpg"
|
||||
|
||||
# Test avatar by ID
|
||||
curl "http://localhost:8081/file/users/1/avatar"
|
||||
|
||||
# Test avatar by username
|
||||
curl "http://localhost:8081/file/users/steven/avatar"
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import (
|
|||
"golang.org/x/sync/semaphore"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/plugin/storage/s3"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
|
|
@ -154,7 +153,7 @@ func (s *FileServerService) serveUserAvatar(c *echo.Context) error {
|
|||
ctx := c.Request().Context()
|
||||
identifier := c.Param("identifier")
|
||||
|
||||
user, err := s.getUserByIdentifier(ctx, identifier)
|
||||
user, err := s.getUserByUsername(ctx, identifier)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get user").Wrap(err)
|
||||
}
|
||||
|
|
@ -530,11 +529,8 @@ func (s *FileServerService) getCurrentUser(ctx context.Context, c *echo.Context)
|
|||
return s.authenticator.AuthenticateToUser(ctx, authHeader, cookieHeader)
|
||||
}
|
||||
|
||||
// getUserByIdentifier finds a user by either ID or username.
|
||||
func (s *FileServerService) getUserByIdentifier(ctx context.Context, identifier string) (*store.User, error) {
|
||||
if userID, err := util.ConvertStringToInt32(identifier); err == nil {
|
||||
return s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
|
||||
}
|
||||
// getUserByUsername finds a user by username only.
|
||||
func (s *FileServerService) getUserByUsername(ctx context.Context, identifier string) (*store.User, error) {
|
||||
return s.Store.GetUser(ctx, &store.FindUser{Username: &identifier})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ This package implements a [Model Context Protocol (MCP)](https://modelcontextpro
|
|||
```
|
||||
POST /mcp (tool calls, initialize)
|
||||
GET /mcp (optional SSE stream for server-to-client messages)
|
||||
DELETE /mcp (optional session termination)
|
||||
```
|
||||
|
||||
Transport: [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports) (single endpoint, MCP spec 2025-03-26).
|
||||
|
|
@ -24,13 +25,22 @@ The server advertises the following MCP capabilities:
|
|||
|
||||
## Authentication
|
||||
|
||||
Every request must include a Personal Access Token (PAT):
|
||||
Public reads can be used without authentication. Personal Access Tokens (PATs) or short-lived JWT session tokens are required for:
|
||||
|
||||
- Reading non-public memos or attachments
|
||||
- Any tool that mutates data
|
||||
|
||||
When authenticating, send a Bearer token:
|
||||
|
||||
```
|
||||
Authorization: Bearer <your-PAT>
|
||||
```
|
||||
|
||||
PATs are long-lived tokens created in Settings → My Account → Access Tokens. Short-lived JWT session tokens are also accepted. Requests without a valid token receive `HTTP 401`.
|
||||
PATs are long-lived tokens created in Settings → My Account → Access Tokens. Short-lived JWT session tokens are also accepted. Requests with an invalid token receive `HTTP 401`.
|
||||
|
||||
## Origin Validation
|
||||
|
||||
For Streamable HTTP safety, requests with an `Origin` header must be same-origin with the current request host or match the configured `instance-url`. Requests without an `Origin` header, such as desktop MCP clients and CLI tools, are allowed.
|
||||
|
||||
## Tools
|
||||
|
||||
|
|
@ -38,7 +48,7 @@ PATs are long-lived tokens created in Settings → My Account → Access Tokens.
|
|||
|
||||
| Tool | Description | Required params | Optional params |
|
||||
|---|---|---|---|
|
||||
| `list_memos` | List memos | — | `page_size`, `page`, `state`, `order_by_pinned`, `filter` (CEL) |
|
||||
| `list_memos` | List memos | — | `page_size`, `page`, `state`, `order_by_pinned`, `filter` (supported subset of standard CEL syntax) |
|
||||
| `get_memo` | Get a single memo | `name` | — |
|
||||
| `search_memos` | Full-text search | `query` | — |
|
||||
| `create_memo` | Create a memo | `content` | `visibility` |
|
||||
|
|
@ -60,15 +70,15 @@ PATs are long-lived tokens created in Settings → My Account → Access Tokens.
|
|||
| `list_attachments` | List user's attachments | — | `page_size`, `page`, `memo` |
|
||||
| `get_attachment` | Get attachment metadata | `name` | — |
|
||||
| `delete_attachment` | Delete an attachment | `name` | — |
|
||||
| `link_attachment_to_memo` | Link attachment to memo | `name`, `memo` | — |
|
||||
| `link_attachment_to_memo` | Link attachment to a memo you own | `name`, `memo` | — |
|
||||
|
||||
### Relation Tools
|
||||
|
||||
| Tool | Description | Required params | Optional params |
|
||||
|---|---|---|---|
|
||||
| `list_memo_relations` | List relations (refs + comments) | `name` | `type` |
|
||||
| `create_memo_relation` | Create a reference relation | `name`, `related_memo` | — |
|
||||
| `delete_memo_relation` | Delete a reference relation | `name`, `related_memo` | — |
|
||||
| `create_memo_relation` | Create a reference relation from a memo you own to a memo you can read | `name`, `related_memo` | — |
|
||||
| `delete_memo_relation` | Delete a reference relation from a memo you own | `name`, `related_memo` | — |
|
||||
|
||||
### Reaction Tools
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,117 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// checkMemoAccess returns an error if the caller cannot read the memo.
|
||||
// userID == 0 means anonymous.
|
||||
func checkMemoAccess(memo *store.Memo, userID int32) error {
|
||||
if memo.RowStatus == store.Archived && memo.CreatorID != userID {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
|
||||
switch memo.Visibility {
|
||||
case store.Protected:
|
||||
if userID == 0 {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
case store.Private:
|
||||
if memo.CreatorID != userID {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
default:
|
||||
// store.Public and any unknown visibility: allow.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkMemoOwnership(memo *store.Memo, userID int32) error {
|
||||
if memo.CreatorID != userID {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasMemoOwnership(memo *store.Memo, userID int32) bool {
|
||||
return memo.CreatorID == userID
|
||||
}
|
||||
|
||||
// applyVisibilityFilter restricts find to memos the caller may see.
|
||||
func applyVisibilityFilter(find *store.FindMemo, userID int32, rowStatus *store.RowStatus) {
|
||||
if rowStatus != nil && *rowStatus == store.Archived {
|
||||
if userID == 0 {
|
||||
impossibleCreatorID := int32(-1)
|
||||
find.CreatorID = &impossibleCreatorID
|
||||
return
|
||||
}
|
||||
find.CreatorID = &userID
|
||||
return
|
||||
}
|
||||
if userID == 0 {
|
||||
find.VisibilityList = []store.Visibility{store.Public}
|
||||
return
|
||||
}
|
||||
find.Filters = append(find.Filters, "creator_id == "+itoa32(userID)+` || visibility in ["PUBLIC", "PROTECTED"]`)
|
||||
}
|
||||
|
||||
func (s *MCPService) checkAttachmentAccess(ctx context.Context, attachment *store.Attachment, userID int32) error {
|
||||
if attachment.CreatorID == userID {
|
||||
return nil
|
||||
}
|
||||
if attachment.MemoID == nil {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
|
||||
memo, err := s.store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get linked memo")
|
||||
}
|
||||
if memo == nil {
|
||||
return errors.New("linked memo not found")
|
||||
}
|
||||
return checkMemoAccess(memo, userID)
|
||||
}
|
||||
|
||||
func (s *MCPService) isAllowedOrigin(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
originURL, err := url.Parse(origin)
|
||||
if err != nil || originURL.Scheme == "" || originURL.Host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if sameOriginHost(originURL.Host, r.Host) {
|
||||
return true
|
||||
}
|
||||
|
||||
if s.profile.InstanceURL == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
instanceURL, err := url.Parse(s.profile.InstanceURL)
|
||||
if err != nil || instanceURL.Scheme == "" || instanceURL.Host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.EqualFold(originURL.Scheme, instanceURL.Scheme) && sameOriginHost(originURL.Host, instanceURL.Host)
|
||||
}
|
||||
|
||||
func sameOriginHost(a, b string) bool {
|
||||
return strings.EqualFold(a, b)
|
||||
}
|
||||
|
||||
func itoa32(v int32) string {
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
}
|
||||
|
|
@ -4,7 +4,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/labstack/echo/v5/middleware"
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
|
|
@ -44,11 +43,22 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
|
|||
httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv)
|
||||
|
||||
mcpGroup := echoServer.Group("")
|
||||
mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
|
||||
AllowOrigins: []string{"*"},
|
||||
}))
|
||||
mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c *echo.Context) error {
|
||||
if !s.isAllowedOrigin(c.Request()) {
|
||||
return c.JSON(http.StatusForbidden, map[string]string{"message": "invalid origin"})
|
||||
}
|
||||
if origin := c.Request().Header.Get("Origin"); origin != "" {
|
||||
headers := c.Response().Header()
|
||||
headers.Set("Vary", "Origin")
|
||||
headers.Set("Access-Control-Allow-Origin", origin)
|
||||
headers.Set("Access-Control-Allow-Headers", "Authorization, Content-Type, Accept, Mcp-Session-Id, MCP-Protocol-Version, Last-Event-ID")
|
||||
headers.Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
|
||||
if c.Request().Method == http.MethodOptions {
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader != "" {
|
||||
result := s.authenticator.Authenticate(c.Request().Context(), authHeader)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,275 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
teststore "github.com/usememos/memos/store/test"
|
||||
)
|
||||
|
||||
type testMCPService struct {
|
||||
service *MCPService
|
||||
store *store.Store
|
||||
}
|
||||
|
||||
func newTestMCPService(t *testing.T) *testMCPService {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
stores := teststore.NewTestingStore(ctx, t)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, stores.Close())
|
||||
})
|
||||
|
||||
svc := NewMCPService(&profile.Profile{
|
||||
Driver: "sqlite",
|
||||
InstanceURL: "https://notes.example.com",
|
||||
}, stores, "test-secret")
|
||||
return &testMCPService{
|
||||
service: svc,
|
||||
store: stores,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testMCPService) createUser(t *testing.T, username string) *store.User {
|
||||
t.Helper()
|
||||
|
||||
user, err := s.store.CreateUser(context.Background(), &store.User{
|
||||
Username: username,
|
||||
Role: store.RoleUser,
|
||||
Email: username + "@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return user
|
||||
}
|
||||
|
||||
func (s *testMCPService) createMemo(t *testing.T, creatorID int32, visibility store.Visibility, content string) *store.Memo {
|
||||
t.Helper()
|
||||
|
||||
memo, err := s.store.CreateMemo(context.Background(), &store.Memo{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: creatorID,
|
||||
RowStatus: store.Normal,
|
||||
Visibility: visibility,
|
||||
Content: content,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return memo
|
||||
}
|
||||
|
||||
func (s *testMCPService) archiveMemo(t *testing.T, memoID int32) {
|
||||
t.Helper()
|
||||
|
||||
rowStatus := store.Archived
|
||||
require.NoError(t, s.store.UpdateMemo(context.Background(), &store.UpdateMemo{
|
||||
ID: memoID,
|
||||
RowStatus: &rowStatus,
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *testMCPService) createAttachment(t *testing.T, creatorID int32, memoID *int32) *store.Attachment {
|
||||
t.Helper()
|
||||
|
||||
attachment, err := s.store.CreateAttachment(context.Background(), &store.Attachment{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: creatorID,
|
||||
Filename: "note.txt",
|
||||
Type: "text/plain",
|
||||
Size: 4,
|
||||
StorageType: storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED,
|
||||
Reference: "db://attachment/note.txt",
|
||||
MemoID: memoID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return attachment
|
||||
}
|
||||
|
||||
func withUser(ctx context.Context, userID int32) context.Context {
|
||||
return context.WithValue(ctx, auth.UserIDContextKey, userID)
|
||||
}
|
||||
|
||||
func toolRequest(name string, arguments map[string]any) mcp.CallToolRequest {
|
||||
return mcp.CallToolRequest{
|
||||
Params: mcp.CallToolParams{
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func firstText(t *testing.T, result *mcp.CallToolResult) string {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, result.Content)
|
||||
text, ok := result.Content[0].(mcp.TextContent)
|
||||
require.True(t, ok)
|
||||
return text.Text
|
||||
}
|
||||
|
||||
func TestHandleGetMemoAndReadResourceDenyArchivedMemoToNonCreator(t *testing.T) {
|
||||
ts := newTestMCPService(t)
|
||||
owner := ts.createUser(t, "owner")
|
||||
other := ts.createUser(t, "other")
|
||||
|
||||
memo := ts.createMemo(t, owner.ID, store.Public, "archived")
|
||||
ts.archiveMemo(t, memo.ID)
|
||||
|
||||
ctx := withUser(context.Background(), other.ID)
|
||||
result, err := ts.service.handleGetMemo(ctx, toolRequest("get_memo", map[string]any{
|
||||
"name": "memos/" + memo.UID,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.IsError)
|
||||
require.Contains(t, firstText(t, result), "permission denied")
|
||||
|
||||
_, err = ts.service.handleReadMemoResource(ctx, mcp.ReadResourceRequest{
|
||||
Params: mcp.ReadResourceParams{
|
||||
URI: "memo://memos/" + memo.UID,
|
||||
},
|
||||
})
|
||||
require.ErrorContains(t, err, "permission denied")
|
||||
}
|
||||
|
||||
func TestHandleListMemosArchivedOnlyReturnsCreatorMemos(t *testing.T) {
|
||||
ts := newTestMCPService(t)
|
||||
owner := ts.createUser(t, "owner")
|
||||
other := ts.createUser(t, "other")
|
||||
|
||||
ownerMemo := ts.createMemo(t, owner.ID, store.Public, "owner archived")
|
||||
ts.archiveMemo(t, ownerMemo.ID)
|
||||
otherMemo := ts.createMemo(t, other.ID, store.Public, "other archived")
|
||||
ts.archiveMemo(t, otherMemo.ID)
|
||||
|
||||
result, err := ts.service.handleListMemos(withUser(context.Background(), owner.ID), toolRequest("list_memos", map[string]any{
|
||||
"state": "ARCHIVED",
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.IsError)
|
||||
|
||||
var payload struct {
|
||||
Memos []memoJSON `json:"memos"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal([]byte(firstText(t, result)), &payload))
|
||||
require.Len(t, payload.Memos, 1)
|
||||
require.Equal(t, "memos/"+ownerMemo.UID, payload.Memos[0].Name)
|
||||
|
||||
anonResult, err := ts.service.handleListMemos(context.Background(), toolRequest("list_memos", map[string]any{
|
||||
"state": "ARCHIVED",
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, json.Unmarshal([]byte(firstText(t, anonResult)), &payload))
|
||||
require.Empty(t, payload.Memos)
|
||||
}
|
||||
|
||||
func TestHandleListMemoRelationsFiltersUnreadableTargets(t *testing.T) {
|
||||
ts := newTestMCPService(t)
|
||||
owner := ts.createUser(t, "owner")
|
||||
privateUser := ts.createUser(t, "private-user")
|
||||
publicUser := ts.createUser(t, "public-user")
|
||||
|
||||
source := ts.createMemo(t, owner.ID, store.Public, "source")
|
||||
privateTarget := ts.createMemo(t, privateUser.ID, store.Private, "private")
|
||||
publicTarget := ts.createMemo(t, publicUser.ID, store.Public, "public")
|
||||
|
||||
_, err := ts.store.UpsertMemoRelation(context.Background(), &store.MemoRelation{
|
||||
MemoID: source.ID,
|
||||
RelatedMemoID: privateTarget.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ts.store.UpsertMemoRelation(context.Background(), &store.MemoRelation{
|
||||
MemoID: source.ID,
|
||||
RelatedMemoID: publicTarget.ID,
|
||||
Type: store.MemoRelationReference,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := ts.service.handleListMemoRelations(context.Background(), toolRequest("list_memo_relations", map[string]any{
|
||||
"name": "memos/" + source.UID,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.IsError)
|
||||
|
||||
var relations []relationJSON
|
||||
require.NoError(t, json.Unmarshal([]byte(firstText(t, result)), &relations))
|
||||
require.Len(t, relations, 1)
|
||||
require.Equal(t, "memos/"+publicTarget.UID, relations[0].RelatedMemo)
|
||||
|
||||
denied, err := ts.service.handleListMemoRelations(context.Background(), toolRequest("list_memo_relations", map[string]any{
|
||||
"name": "memos/" + privateTarget.UID,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.True(t, denied.IsError)
|
||||
require.Contains(t, firstText(t, denied), "permission denied")
|
||||
}
|
||||
|
||||
func TestHandleLinkAttachmentToMemoRequiresMemoOwnership(t *testing.T) {
|
||||
ts := newTestMCPService(t)
|
||||
attachmentOwner := ts.createUser(t, "attachment-owner")
|
||||
memoOwner := ts.createUser(t, "memo-owner")
|
||||
|
||||
attachment := ts.createAttachment(t, attachmentOwner.ID, nil)
|
||||
memo := ts.createMemo(t, memoOwner.ID, store.Public, "target")
|
||||
|
||||
result, err := ts.service.handleLinkAttachmentToMemo(withUser(context.Background(), attachmentOwner.ID), toolRequest("link_attachment_to_memo", map[string]any{
|
||||
"name": "attachments/" + attachment.UID,
|
||||
"memo": "memos/" + memo.UID,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.IsError)
|
||||
require.Contains(t, firstText(t, result), "permission denied")
|
||||
}
|
||||
|
||||
func TestHandleGetAttachmentDeniesArchivedLinkedMemoToNonCreator(t *testing.T) {
|
||||
ts := newTestMCPService(t)
|
||||
owner := ts.createUser(t, "owner")
|
||||
other := ts.createUser(t, "other")
|
||||
|
||||
memo := ts.createMemo(t, owner.ID, store.Public, "memo")
|
||||
ts.archiveMemo(t, memo.ID)
|
||||
attachment := ts.createAttachment(t, owner.ID, &memo.ID)
|
||||
|
||||
result, err := ts.service.handleGetAttachment(withUser(context.Background(), other.ID), toolRequest("get_attachment", map[string]any{
|
||||
"name": "attachments/" + attachment.UID,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.IsError)
|
||||
require.Contains(t, firstText(t, result), "permission denied")
|
||||
}
|
||||
|
||||
func TestIsAllowedOrigin(t *testing.T) {
|
||||
ts := newTestMCPService(t)
|
||||
|
||||
t.Run("allow missing origin", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil)
|
||||
require.True(t, ts.service.isAllowedOrigin(req))
|
||||
})
|
||||
|
||||
t.Run("allow same origin as request host", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil)
|
||||
req.Header.Set("Origin", "http://localhost:5230")
|
||||
require.True(t, ts.service.isAllowedOrigin(req))
|
||||
})
|
||||
|
||||
t.Run("allow configured instance origin", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "http://127.0.0.1:5230/mcp", nil)
|
||||
req.Host = "127.0.0.1:5230"
|
||||
req.Header.Set("Origin", "https://notes.example.com")
|
||||
require.True(t, ts.service.isAllowedOrigin(req))
|
||||
})
|
||||
|
||||
t.Run("reject cross origin", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil)
|
||||
req.Header.Set("Origin", "https://evil.example.com")
|
||||
require.False(t, ts.service.isAllowedOrigin(req))
|
||||
})
|
||||
}
|
||||
|
|
@ -48,7 +48,10 @@ func (s *MCPService) handleReadMemoResource(ctx context.Context, req mcp.ReadRes
|
|||
return nil, err
|
||||
}
|
||||
|
||||
j := storeMemoToJSON(memo)
|
||||
j, err := storeMemoToJSONWithStore(ctx, s.store, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to resolve memo creator")
|
||||
}
|
||||
text := formatMemoMarkdown(j)
|
||||
|
||||
return []mcp.ResourceContents{
|
||||
|
|
|
|||
|
|
@ -26,10 +26,14 @@ type attachmentJSON struct {
|
|||
Memo string `json:"memo,omitempty"`
|
||||
}
|
||||
|
||||
func storeAttachmentToJSON(a *store.Attachment) attachmentJSON {
|
||||
func storeAttachmentToJSON(ctx context.Context, stores *store.Store, a *store.Attachment) (attachmentJSON, error) {
|
||||
creator, err := lookupUsername(ctx, stores, a.CreatorID)
|
||||
if err != nil {
|
||||
return attachmentJSON{}, errors.Wrap(err, "lookup attachment creator username")
|
||||
}
|
||||
j := attachmentJSON{
|
||||
Name: "attachments/" + a.UID,
|
||||
Creator: fmt.Sprintf("users/%d", a.CreatorID),
|
||||
Creator: creator,
|
||||
CreateTime: a.CreatedTs,
|
||||
Filename: a.Filename,
|
||||
Type: a.Type,
|
||||
|
|
@ -50,7 +54,38 @@ func storeAttachmentToJSON(a *store.Attachment) attachmentJSON {
|
|||
if a.MemoUID != nil && *a.MemoUID != "" {
|
||||
j.Memo = "memos/" + *a.MemoUID
|
||||
}
|
||||
return j
|
||||
return j, nil
|
||||
}
|
||||
|
||||
func storeAttachmentToJSONWithUsernames(a *store.Attachment, usernamesByID map[int32]string) (attachmentJSON, error) {
|
||||
creator, err := lookupUsernameFromCache(usernamesByID, a.CreatorID)
|
||||
if err != nil {
|
||||
return attachmentJSON{}, errors.Wrap(err, "lookup attachment creator username from cache")
|
||||
}
|
||||
j := attachmentJSON{
|
||||
Name: "attachments/" + a.UID,
|
||||
Creator: creator,
|
||||
CreateTime: a.CreatedTs,
|
||||
Filename: a.Filename,
|
||||
Type: a.Type,
|
||||
Size: a.Size,
|
||||
}
|
||||
switch a.StorageType {
|
||||
case storepb.AttachmentStorageType_LOCAL:
|
||||
j.StorageType = "LOCAL"
|
||||
case storepb.AttachmentStorageType_S3:
|
||||
j.StorageType = "S3"
|
||||
j.ExternalLink = a.Reference
|
||||
case storepb.AttachmentStorageType_EXTERNAL:
|
||||
j.StorageType = "EXTERNAL"
|
||||
j.ExternalLink = a.Reference
|
||||
default:
|
||||
j.StorageType = "DATABASE"
|
||||
}
|
||||
if a.MemoUID != nil && *a.MemoUID != "" {
|
||||
j.Memo = "memos/" + *a.MemoUID
|
||||
}
|
||||
return j, nil
|
||||
}
|
||||
|
||||
func parseAttachmentUID(name string) (string, error) {
|
||||
|
|
@ -136,10 +171,22 @@ func (s *MCPService) handleListAttachments(ctx context.Context, req mcp.CallTool
|
|||
if hasMore {
|
||||
attachments = attachments[:pageSize]
|
||||
}
|
||||
creatorIDs := make([]int32, 0, len(attachments))
|
||||
for _, attachment := range attachments {
|
||||
creatorIDs = append(creatorIDs, attachment.CreatorID)
|
||||
}
|
||||
usernamesByID, err := preloadUsernames(ctx, s.store, creatorIDs)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to preload attachment creators: %v", err)), nil
|
||||
}
|
||||
|
||||
results := make([]attachmentJSON, len(attachments))
|
||||
for i, a := range attachments {
|
||||
results[i] = storeAttachmentToJSON(a)
|
||||
result, err := storeAttachmentToJSONWithUsernames(a, usernamesByID)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil
|
||||
}
|
||||
results[i] = result
|
||||
}
|
||||
|
||||
type listResponse struct {
|
||||
|
|
@ -169,24 +216,15 @@ func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRe
|
|||
return mcp.NewToolResultError("attachment not found"), nil
|
||||
}
|
||||
|
||||
// Check access: creator can always access; linked memo visibility applies otherwise.
|
||||
if attachment.CreatorID != userID {
|
||||
if attachment.MemoID != nil {
|
||||
memo, err := s.store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to get linked memo: %v", err)), nil
|
||||
}
|
||||
if memo != nil {
|
||||
if err := checkMemoAccess(memo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return mcp.NewToolResultError("permission denied"), nil
|
||||
}
|
||||
if err := s.checkAttachmentAccess(ctx, attachment, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
out, err := marshalJSON(storeAttachmentToJSON(attachment))
|
||||
result, err := storeAttachmentToJSON(ctx, s.store, attachment)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil
|
||||
}
|
||||
out, err := marshalJSON(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -251,6 +289,9 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal
|
|||
if memo == nil {
|
||||
return mcp.NewToolResultError("memo not found"), nil
|
||||
}
|
||||
if err := checkMemoOwnership(memo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
if err := s.store.UpdateAttachment(ctx, &store.UpdateAttachment{
|
||||
ID: attachment.ID,
|
||||
|
|
@ -264,7 +305,11 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal
|
|||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to fetch updated attachment: %v", err)), nil
|
||||
}
|
||||
out, err := marshalJSON(storeAttachmentToJSON(updated))
|
||||
result, err := storeAttachmentToJSON(ctx, s.store, updated)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil
|
||||
}
|
||||
out, err := marshalJSON(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -75,7 +75,6 @@ type memoJSON struct {
|
|||
func storeMemoToJSON(m *store.Memo) memoJSON {
|
||||
j := memoJSON{
|
||||
Name: "memos/" + m.UID,
|
||||
Creator: fmt.Sprintf("users/%d", m.CreatorID),
|
||||
CreateTime: m.CreatedTs,
|
||||
UpdateTime: m.UpdatedTs,
|
||||
Content: m.Content,
|
||||
|
|
@ -103,31 +102,70 @@ func storeMemoToJSON(m *store.Memo) memoJSON {
|
|||
return j
|
||||
}
|
||||
|
||||
// checkMemoAccess returns an error if the caller cannot read memo.
|
||||
// userID == 0 means anonymous.
|
||||
func checkMemoAccess(memo *store.Memo, userID int32) error {
|
||||
switch memo.Visibility {
|
||||
case store.Protected:
|
||||
if userID == 0 {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
case store.Private:
|
||||
if memo.CreatorID != userID {
|
||||
return errors.New("permission denied")
|
||||
}
|
||||
default:
|
||||
// store.Public and any unknown visibility: allow
|
||||
func lookupUsername(ctx context.Context, stores *store.Store, userID int32) (string, error) {
|
||||
user, err := stores.GetUser(ctx, &store.FindUser{ID: &userID})
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "failed to get creator user %d", userID)
|
||||
}
|
||||
return nil
|
||||
if user == nil {
|
||||
return "", errors.Errorf("creator user %d not found", userID)
|
||||
}
|
||||
return "users/" + user.Username, nil
|
||||
}
|
||||
|
||||
// applyVisibilityFilter restricts find to memos the caller may see.
|
||||
func applyVisibilityFilter(find *store.FindMemo, userID int32) {
|
||||
if userID == 0 {
|
||||
find.VisibilityList = []store.Visibility{store.Public}
|
||||
} else {
|
||||
find.Filters = append(find.Filters, fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, userID))
|
||||
func preloadUsernames(ctx context.Context, stores *store.Store, userIDs []int32) (map[int32]string, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return map[int32]string{}, nil
|
||||
}
|
||||
|
||||
uniqueUserIDs := make([]int32, 0, len(userIDs))
|
||||
seenUserIDs := make(map[int32]struct{}, len(userIDs))
|
||||
for _, userID := range userIDs {
|
||||
if _, seen := seenUserIDs[userID]; seen {
|
||||
continue
|
||||
}
|
||||
seenUserIDs[userID] = struct{}{}
|
||||
uniqueUserIDs = append(uniqueUserIDs, userID)
|
||||
}
|
||||
|
||||
users, err := stores.ListUsers(ctx, &store.FindUser{IDList: uniqueUserIDs})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list creator users")
|
||||
}
|
||||
|
||||
usernamesByID := make(map[int32]string, len(users))
|
||||
for _, user := range users {
|
||||
usernamesByID[user.ID] = "users/" + user.Username
|
||||
}
|
||||
return usernamesByID, nil
|
||||
}
|
||||
|
||||
func lookupUsernameFromCache(usernamesByID map[int32]string, userID int32) (string, error) {
|
||||
username, ok := usernamesByID[userID]
|
||||
if !ok {
|
||||
return "", errors.Errorf("creator user %d not found", userID)
|
||||
}
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func storeMemoToJSONWithStore(ctx context.Context, stores *store.Store, m *store.Memo) (memoJSON, error) {
|
||||
j := storeMemoToJSON(m)
|
||||
creator, err := lookupUsername(ctx, stores, m.CreatorID)
|
||||
if err != nil {
|
||||
return memoJSON{}, err
|
||||
}
|
||||
j.Creator = creator
|
||||
return j, nil
|
||||
}
|
||||
|
||||
func storeMemoToJSONWithUsernames(m *store.Memo, usernamesByID map[int32]string) (memoJSON, error) {
|
||||
j := storeMemoToJSON(m)
|
||||
creator, err := lookupUsernameFromCache(usernamesByID, m.CreatorID)
|
||||
if err != nil {
|
||||
return memoJSON{}, err
|
||||
}
|
||||
j.Creator = creator
|
||||
return j, nil
|
||||
}
|
||||
|
||||
// parseMemoUID extracts the UID from a "memos/<uid>" resource name.
|
||||
|
|
@ -185,7 +223,7 @@ func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) {
|
|||
mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"),
|
||||
),
|
||||
mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")),
|
||||
mcp.WithString("filter", mcp.Description(`Optional CEL filter, e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
|
||||
mcp.WithString("filter", mcp.Description(`Optional CEL filter (supported subset of standard CEL syntax), e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
|
||||
), s.handleListMemos)
|
||||
|
||||
mcpSrv.AddTool(mcp.NewTool("get_memo",
|
||||
|
|
@ -272,7 +310,7 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques
|
|||
Offset: &offset,
|
||||
OrderByPinned: req.GetBool("order_by_pinned", false),
|
||||
}
|
||||
applyVisibilityFilter(find, userID)
|
||||
applyVisibilityFilter(find, userID, rowStatus)
|
||||
if filter := req.GetString("filter", ""); filter != "" {
|
||||
find.Filters = append(find.Filters, filter)
|
||||
}
|
||||
|
|
@ -286,10 +324,22 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques
|
|||
if hasMore {
|
||||
memos = memos[:pageSize]
|
||||
}
|
||||
creatorIDs := make([]int32, 0, len(memos))
|
||||
for _, memo := range memos {
|
||||
creatorIDs = append(creatorIDs, memo.CreatorID)
|
||||
}
|
||||
usernamesByID, err := preloadUsernames(ctx, s.store, creatorIDs)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to preload memo creators: %v", err)), nil
|
||||
}
|
||||
|
||||
results := make([]memoJSON, len(memos))
|
||||
for i, m := range memos {
|
||||
results[i] = storeMemoToJSON(m)
|
||||
result, err := storeMemoToJSONWithUsernames(m, usernamesByID)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil
|
||||
}
|
||||
results[i] = result
|
||||
}
|
||||
|
||||
type listResponse struct {
|
||||
|
|
@ -322,7 +372,11 @@ func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest)
|
|||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
out, err := marshalJSON(storeMemoToJSON(memo))
|
||||
result, err := storeMemoToJSONWithStore(ctx, s.store, memo)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil
|
||||
}
|
||||
out, err := marshalJSON(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -355,7 +409,11 @@ func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolReque
|
|||
return mcp.NewToolResultError(fmt.Sprintf("failed to create memo: %v", err)), nil
|
||||
}
|
||||
|
||||
out, err := marshalJSON(storeMemoToJSON(memo))
|
||||
result, err := storeMemoToJSONWithStore(ctx, s.store, memo)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil
|
||||
}
|
||||
out, err := marshalJSON(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -380,8 +438,8 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque
|
|||
if memo == nil {
|
||||
return mcp.NewToolResultError("memo not found"), nil
|
||||
}
|
||||
if memo.CreatorID != userID {
|
||||
return mcp.NewToolResultError("permission denied"), nil
|
||||
if err := checkMemoOwnership(memo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
update := &store.UpdateMemo{ID: memo.ID}
|
||||
|
|
@ -419,7 +477,11 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque
|
|||
return mcp.NewToolResultError(fmt.Sprintf("failed to fetch updated memo: %v", err)), nil
|
||||
}
|
||||
|
||||
out, err := marshalJSON(storeMemoToJSON(updated))
|
||||
result, err := storeMemoToJSONWithStore(ctx, s.store, updated)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil
|
||||
}
|
||||
out, err := marshalJSON(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -444,8 +506,8 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque
|
|||
if memo == nil {
|
||||
return mcp.NewToolResultError("memo not found"), nil
|
||||
}
|
||||
if memo.CreatorID != userID {
|
||||
return mcp.NewToolResultError("permission denied"), nil
|
||||
if err := checkMemoOwnership(memo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
if err := s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
|
||||
|
|
@ -472,16 +534,28 @@ func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequ
|
|||
Offset: &zero,
|
||||
Filters: []string{fmt.Sprintf(`content.contains(%q)`, query)},
|
||||
}
|
||||
applyVisibilityFilter(find, userID)
|
||||
applyVisibilityFilter(find, userID, find.RowStatus)
|
||||
|
||||
memos, err := s.store.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to search memos: %v", err)), nil
|
||||
}
|
||||
creatorIDs := make([]int32, 0, len(memos))
|
||||
for _, memo := range memos {
|
||||
creatorIDs = append(creatorIDs, memo.CreatorID)
|
||||
}
|
||||
usernamesByID, err := preloadUsernames(ctx, s.store, creatorIDs)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to preload memo creators: %v", err)), nil
|
||||
}
|
||||
|
||||
results := make([]memoJSON, len(memos))
|
||||
for i, m := range memos {
|
||||
results[i] = storeMemoToJSON(m)
|
||||
result, err := storeMemoToJSONWithUsernames(m, usernamesByID)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil
|
||||
}
|
||||
results[i] = result
|
||||
}
|
||||
out, err := marshalJSON(results)
|
||||
if err != nil {
|
||||
|
|
@ -531,11 +605,25 @@ func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToo
|
|||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to list comments: %v", err)), nil
|
||||
}
|
||||
creatorIDs := make([]int32, 0, len(memos))
|
||||
for _, memo := range memos {
|
||||
if checkMemoAccess(memo, userID) == nil {
|
||||
creatorIDs = append(creatorIDs, memo.CreatorID)
|
||||
}
|
||||
}
|
||||
usernamesByID, err := preloadUsernames(ctx, s.store, creatorIDs)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to preload memo creators: %v", err)), nil
|
||||
}
|
||||
|
||||
results := make([]memoJSON, 0, len(memos))
|
||||
for _, m := range memos {
|
||||
if checkMemoAccess(m, userID) == nil {
|
||||
results = append(results, storeMemoToJSON(m))
|
||||
result, err := storeMemoToJSONWithUsernames(m, usernamesByID)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
}
|
||||
out, err := marshalJSON(results)
|
||||
|
|
@ -591,7 +679,11 @@ func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallTo
|
|||
return mcp.NewToolResultError(fmt.Sprintf("failed to link comment: %v", err)), nil
|
||||
}
|
||||
|
||||
out, err := marshalJSON(storeMemoToJSON(comment))
|
||||
result, err := storeMemoToJSONWithStore(ctx, s.store, comment)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil
|
||||
}
|
||||
out, err := marshalJSON(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -60,12 +60,24 @@ func (s *MCPService) handleListReactions(ctx context.Context, req mcp.CallToolRe
|
|||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to list reactions: %v", err)), nil
|
||||
}
|
||||
creatorIDs := make([]int32, 0, len(reactions))
|
||||
for _, reaction := range reactions {
|
||||
creatorIDs = append(creatorIDs, reaction.CreatorID)
|
||||
}
|
||||
usernamesByID, err := preloadUsernames(ctx, s.store, creatorIDs)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to preload reaction creators: %v", err)), nil
|
||||
}
|
||||
|
||||
results := make([]reactionJSON, len(reactions))
|
||||
for i, r := range reactions {
|
||||
creator, err := lookupUsernameFromCache(usernamesByID, r.CreatorID)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve reaction creator: %v", err)), nil
|
||||
}
|
||||
results[i] = reactionJSON{
|
||||
ID: r.ID,
|
||||
Creator: fmt.Sprintf("users/%d", r.CreatorID),
|
||||
Creator: creator,
|
||||
ReactionType: r.ReactionType,
|
||||
CreateTime: r.CreatedTs,
|
||||
}
|
||||
|
|
@ -130,9 +142,13 @@ func (s *MCPService) handleUpsertReaction(ctx context.Context, req mcp.CallToolR
|
|||
return mcp.NewToolResultError(fmt.Sprintf("failed to upsert reaction: %v", err)), nil
|
||||
}
|
||||
|
||||
creator, err := lookupUsername(ctx, s.store, reaction.CreatorID)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve reaction creator: %v", err)), nil
|
||||
}
|
||||
out, err := marshalJSON(reactionJSON{
|
||||
ID: reaction.ID,
|
||||
Creator: fmt.Sprintf("users/%d", reaction.CreatorID),
|
||||
Creator: creator,
|
||||
ReactionType: reaction.ReactionType,
|
||||
CreateTime: reaction.CreatedTs,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/mark3labs/mcp-go/mcp"
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
|
|
@ -40,6 +41,8 @@ func (s *MCPService) registerRelationTools(mcpSrv *mcpserver.MCPServer) {
|
|||
}
|
||||
|
||||
func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
userID := auth.GetUserID(ctx)
|
||||
|
||||
uid, err := parseMemoUID(req.GetString("name", ""))
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
|
|
@ -52,6 +55,9 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo
|
|||
if memo == nil {
|
||||
return mcp.NewToolResultError("memo not found"), nil
|
||||
}
|
||||
if err := checkMemoAccess(memo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
find := &store.FindMemoRelation{
|
||||
MemoIDList: []int32{memo.ID},
|
||||
|
|
@ -85,21 +91,24 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo
|
|||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memos: %v", err)), nil
|
||||
}
|
||||
uidByID := make(map[int32]string, len(memos))
|
||||
memoByID := make(map[int32]*store.Memo, len(memos))
|
||||
for _, m := range memos {
|
||||
uidByID[m.ID] = m.UID
|
||||
memoByID[m.ID] = m
|
||||
}
|
||||
|
||||
results := make([]relationJSON, 0, len(relations))
|
||||
for _, r := range relations {
|
||||
memoUID, ok1 := uidByID[r.MemoID]
|
||||
relatedUID, ok2 := uidByID[r.RelatedMemoID]
|
||||
srcMemo, ok1 := memoByID[r.MemoID]
|
||||
relatedMemo, ok2 := memoByID[r.RelatedMemoID]
|
||||
if !ok1 || !ok2 {
|
||||
continue
|
||||
}
|
||||
if checkMemoAccess(srcMemo, userID) != nil || checkMemoAccess(relatedMemo, userID) != nil {
|
||||
continue
|
||||
}
|
||||
results = append(results, relationJSON{
|
||||
Memo: "memos/" + memoUID,
|
||||
RelatedMemo: "memos/" + relatedUID,
|
||||
Memo: "memos/" + srcMemo.UID,
|
||||
RelatedMemo: "memos/" + relatedMemo.UID,
|
||||
Type: string(r.Type),
|
||||
})
|
||||
}
|
||||
|
|
@ -133,7 +142,7 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT
|
|||
if srcMemo == nil {
|
||||
return mcp.NewToolResultError("source memo not found"), nil
|
||||
}
|
||||
if srcMemo.CreatorID != userID {
|
||||
if !hasMemoOwnership(srcMemo, userID) {
|
||||
return mcp.NewToolResultError("permission denied: must own the source memo"), nil
|
||||
}
|
||||
|
||||
|
|
@ -144,6 +153,9 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT
|
|||
if dstMemo == nil {
|
||||
return mcp.NewToolResultError("related memo not found"), nil
|
||||
}
|
||||
if err := checkMemoAccess(dstMemo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
relation, err := s.store.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: srcMemo.ID,
|
||||
|
|
@ -187,7 +199,7 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT
|
|||
if srcMemo == nil {
|
||||
return mcp.NewToolResultError("source memo not found"), nil
|
||||
}
|
||||
if srcMemo.CreatorID != userID {
|
||||
if !hasMemoOwnership(srcMemo, userID) {
|
||||
return mcp.NewToolResultError("permission denied: must own the source memo"), nil
|
||||
}
|
||||
|
||||
|
|
@ -198,6 +210,9 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT
|
|||
if dstMemo == nil {
|
||||
return mcp.NewToolResultError("related memo not found"), nil
|
||||
}
|
||||
if err := checkMemoAccess(dstMemo, userID); err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
refType := store.MemoRelationReference
|
||||
if err := s.store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest)
|
|||
ExcludeContent: true,
|
||||
RowStatus: &rowStatus,
|
||||
}
|
||||
applyVisibilityFilter(find, userID)
|
||||
applyVisibilityFilter(find, userID, find.RowStatus)
|
||||
|
||||
memos, err := s.store.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -145,6 +145,7 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
|
|||
}
|
||||
|
||||
query := "SELECT " + strings.Join(fields, ", ") + " FROM `memo`" + " " +
|
||||
"LEFT JOIN `user` AS `memo_creator` ON `memo`.`creator_id` = `memo_creator`.`id`" + " " +
|
||||
"LEFT JOIN `memo_relation` ON `memo`.`id` = `memo_relation`.`memo_id` AND `memo_relation`.`type` = 'COMMENT'" + " " +
|
||||
"LEFT JOIN `memo` AS `parent_memo` ON `memo_relation`.`related_memo_id` = `parent_memo`.`id`" + " " +
|
||||
"WHERE " + strings.Join(where, " AND ") + " " +
|
||||
|
|
|
|||
|
|
@ -131,6 +131,7 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
|
|||
|
||||
query := `SELECT ` + strings.Join(fields, ", ") + `
|
||||
FROM memo
|
||||
LEFT JOIN "user" AS memo_creator ON memo.creator_id = memo_creator.id
|
||||
LEFT JOIN memo_relation ON memo.id = memo_relation.memo_id AND memo_relation.type = 'COMMENT'
|
||||
LEFT JOIN memo AS parent_memo ON memo_relation.related_memo_id = parent_memo.id
|
||||
WHERE ` + strings.Join(where, " AND ") + `
|
||||
|
|
|
|||
|
|
@ -137,6 +137,7 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
|
|||
}
|
||||
|
||||
query := "SELECT " + strings.Join(fields, ", ") + "FROM `memo` " +
|
||||
"LEFT JOIN `user` AS `memo_creator` ON `memo`.`creator_id` = `memo_creator`.`id` " +
|
||||
"LEFT JOIN `memo_relation` ON `memo`.`id` = `memo_relation`.`memo_id` AND `memo_relation`.`type` = \"COMMENT\" " +
|
||||
"LEFT JOIN `memo` AS `parent_memo` ON `memo_relation`.`related_memo_id` = `parent_memo`.`id` " +
|
||||
"WHERE " + strings.Join(where, " AND ") + " " +
|
||||
|
|
|
|||
|
|
@ -219,7 +219,7 @@ func (s *Store) GetInstanceNotificationSetting(ctx context.Context) (*storepb.In
|
|||
const (
|
||||
defaultInstanceStorageType = storepb.InstanceStorageSetting_LOCAL
|
||||
defaultInstanceUploadSizeLimitMb = 30
|
||||
defaultInstanceFilepathTemplate = "assets/{timestamp}_{filename}"
|
||||
defaultInstanceFilepathTemplate = "assets/{timestamp}_{uuid}_{filename}"
|
||||
)
|
||||
|
||||
func (s *Store) GetInstanceStorageSetting(ctx context.Context) (*storepb.InstanceStorageSetting, error) {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
|
@ -12,6 +14,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
mysqldriver "github.com/go-sql-driver/mysql"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mysql"
|
||||
|
|
@ -20,7 +23,6 @@ import (
|
|||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
|
||||
// Database drivers for connection verification.
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
|
|
@ -31,6 +33,9 @@ const (
|
|||
// Memos container settings for migration testing.
|
||||
MemosDockerImage = "neosmemo/memos"
|
||||
StableMemosVersion = "stable" // Always points to the latest stable release
|
||||
|
||||
mysqlNetworkAlias = "memos-mysql"
|
||||
postgresNetworkAlias = "memos-postgres"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -62,12 +67,23 @@ func getTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error)
|
|||
return testDockerNetwork.Load(), networkErr
|
||||
}
|
||||
|
||||
func requireTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error) {
|
||||
nw, err := getTestNetwork(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create test network")
|
||||
}
|
||||
if nw == nil {
|
||||
return nil, errors.New("test network is unavailable")
|
||||
}
|
||||
return nw, nil
|
||||
}
|
||||
|
||||
// GetMySQLDSN starts a MySQL container (if not already running) and creates a fresh database for this test.
|
||||
func GetMySQLDSN(t *testing.T) string {
|
||||
ctx := context.Background()
|
||||
|
||||
mysqlOnce.Do(func() {
|
||||
nw, err := getTestNetwork(ctx)
|
||||
nw, err := requireTestNetwork(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test network: %v", err)
|
||||
}
|
||||
|
|
@ -86,7 +102,7 @@ func GetMySQLDSN(t *testing.T) string {
|
|||
wait.ForListeningPort("3306/tcp"),
|
||||
).WithDeadline(120*time.Second),
|
||||
),
|
||||
network.WithNetwork(nil, nw),
|
||||
network.WithNetwork([]string{mysqlNetworkAlias}, nw),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start MySQL container: %v", err)
|
||||
|
|
@ -167,7 +183,7 @@ func GetPostgresDSN(t *testing.T) string {
|
|||
ctx := context.Background()
|
||||
|
||||
postgresOnce.Do(func() {
|
||||
nw, err := getTestNetwork(ctx)
|
||||
nw, err := requireTestNetwork(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test network: %v", err)
|
||||
}
|
||||
|
|
@ -183,7 +199,7 @@ func GetPostgresDSN(t *testing.T) string {
|
|||
wait.ForListeningPort("5432/tcp"),
|
||||
).WithDeadline(120*time.Second),
|
||||
),
|
||||
network.WithNetwork(nil, nw),
|
||||
network.WithNetwork([]string{postgresNetworkAlias}, nw),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start PostgreSQL container: %v", err)
|
||||
|
|
@ -264,6 +280,11 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon
|
|||
"MEMOS_MODE": "prod",
|
||||
}
|
||||
|
||||
nw, err := requireTestNetwork(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opts []testcontainers.ContainerCustomizer
|
||||
|
||||
switch cfg.Driver {
|
||||
|
|
@ -272,6 +293,12 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon
|
|||
opts = append(opts, testcontainers.WithHostConfigModifier(func(hc *container.HostConfig) {
|
||||
hc.Binds = append(hc.Binds, fmt.Sprintf("%s:%s", cfg.DataDir, "/var/opt/memos"))
|
||||
}))
|
||||
case "mysql", "postgres":
|
||||
if cfg.DSN == "" {
|
||||
return nil, errors.Errorf("dsn is required for %s migration testing", cfg.Driver)
|
||||
}
|
||||
env["MEMOS_DRIVER"] = cfg.Driver
|
||||
env["MEMOS_DSN"] = cfg.DSN
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported driver for migration testing: %s", cfg.Driver)
|
||||
}
|
||||
|
|
@ -303,6 +330,7 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon
|
|||
}
|
||||
|
||||
// Apply options
|
||||
opts = append(opts, network.WithNetwork(nil, nw))
|
||||
for _, opt := range opts {
|
||||
if err := opt.Customize(&genericReq); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to apply container option")
|
||||
|
|
@ -316,3 +344,27 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon
|
|||
|
||||
return ctr, nil
|
||||
}
|
||||
|
||||
func getContainerDSN(driver, hostDSN string) (string, error) {
|
||||
switch driver {
|
||||
case "mysql":
|
||||
cfg, err := mysqldriver.ParseDSN(hostDSN)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to parse mysql dsn")
|
||||
}
|
||||
cfg.Net = "tcp"
|
||||
cfg.Addr = net.JoinHostPort(mysqlNetworkAlias, "3306")
|
||||
return cfg.FormatDSN(), nil
|
||||
case "postgres":
|
||||
u, err := url.Parse(hostDSN)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to parse postgres dsn")
|
||||
}
|
||||
u.Host = net.JoinHostPort(postgresNetworkAlias, "5432")
|
||||
return u.String(), nil
|
||||
case "sqlite":
|
||||
return hostDSN, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported driver for container dsn: %s", driver)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -196,7 +196,7 @@ func TestInstanceSettingStorageSetting(t *testing.T) {
|
|||
require.NotNil(t, storageSetting)
|
||||
require.Equal(t, storepb.InstanceStorageSetting_LOCAL, storageSetting.StorageType)
|
||||
require.Equal(t, int64(30), storageSetting.UploadSizeLimitMb)
|
||||
require.Equal(t, "assets/{timestamp}_{filename}", storageSetting.FilepathTemplate)
|
||||
require.Equal(t, "assets/{timestamp}_{uuid}_{filename}", storageSetting.FilepathTemplate)
|
||||
|
||||
// Set custom storage setting
|
||||
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
|
|
@ -257,6 +257,34 @@ func TestInstanceSettingTagsSetting(t *testing.T) {
|
|||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInstanceSettingTagsSettingWithoutColor(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
_, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
|
||||
Key: storepb.InstanceSettingKey_TAGS,
|
||||
Value: &storepb.InstanceSetting_TagsSetting{
|
||||
TagsSetting: &storepb.InstanceTagsSetting{
|
||||
Tags: map[string]*storepb.InstanceTagMetadata{
|
||||
"spoiler": {
|
||||
BlurContent: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tagsSetting, err := ts.GetInstanceTagsSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, tagsSetting.Tags, "spoiler")
|
||||
require.Nil(t, tagsSetting.Tags["spoiler"].GetBackgroundColor())
|
||||
require.True(t, tagsSetting.Tags["spoiler"].GetBlurContent())
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
func TestInstanceSettingNotificationSetting(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
|
|
|||
|
|
@ -184,6 +184,32 @@ func TestMemoFilterPinnedPredicate(t *testing.T) {
|
|||
require.True(t, memos[0].Pinned)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Creator Field Tests
|
||||
// Schema: creator (string resource name), creator_id (int, ==, !=)
|
||||
// =============================================================================
|
||||
|
||||
func TestMemoFilterCreatorEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
user2, err := tc.Store.CreateUser(tc.Ctx, &store.User{
|
||||
Username: "user2",
|
||||
Role: store.RoleUser,
|
||||
Email: "user2@example.com",
|
||||
Nickname: "User 2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-user1", tc.User.ID).Content("User 1 memo"))
|
||||
tc.CreateMemo(NewMemoBuilder("memo-user2", user2.ID).Content("User 2 memo"))
|
||||
|
||||
memos := tc.ListWithFilter(`creator == "users/` + tc.User.Username + `"`)
|
||||
require.Len(t, memos, 1)
|
||||
require.Equal(t, tc.User.ID, memos[0].CreatorID)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Creator ID Field Tests
|
||||
// Schema: creator_id (int, ==, !=)
|
||||
|
|
@ -704,6 +730,31 @@ func TestMemoFilterTagsExistsContains(t *testing.T) {
|
|||
require.Len(t, memos, 1, "Should find 1 non-todo memo")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsEquals(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
defer tc.Close()
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-1231", tc.User.ID).
|
||||
Content("Memo with exact numeric tag").
|
||||
Tags("1231", "project"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-1231-suffix", tc.User.ID).
|
||||
Content("Memo with related tag").
|
||||
Tags("tag/1231", "other"))
|
||||
|
||||
tc.CreateMemo(NewMemoBuilder("memo-other", tc.User.ID).
|
||||
Content("Memo with different tag").
|
||||
Tags("9999"))
|
||||
|
||||
memos := tc.ListWithFilter(`tags.exists(t, t == "1231")`)
|
||||
require.Len(t, memos, 1, "Should find only the memo with exact matching tag")
|
||||
require.Equal(t, "memo-1231", memos[0].UID)
|
||||
|
||||
memos = tc.ListWithFilter(`!tags.exists(t, t == "1231")`)
|
||||
require.Len(t, memos, 2, "Should exclude only the memo with exact matching tag")
|
||||
}
|
||||
|
||||
func TestMemoFilterTagsExistsEndsWith(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := NewMemoFilterTestContext(t)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,274 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestMigrationFromV0262PreservesLegacyData(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping container-based upgrade test in short mode")
|
||||
}
|
||||
if os.Getenv("SKIP_CONTAINER_TESTS") == "1" {
|
||||
t.Skip("skipping container-based test (SKIP_CONTAINER_TESTS=1)")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
driver := getDriverFromEnv()
|
||||
|
||||
cfg, hostDSN := prepareV0262MigrationTest(t, driver)
|
||||
t.Logf("Starting Memos %s container for %s schema bootstrap...", cfg.Version, driver)
|
||||
container, err := StartMemosContainer(ctx, cfg)
|
||||
require.NoError(t, err, "failed to start v0.26.2 memos container")
|
||||
t.Cleanup(func() {
|
||||
if container != nil {
|
||||
_ = container.Terminate(ctx)
|
||||
}
|
||||
})
|
||||
|
||||
legacyStore := NewTestingStoreWithDSN(ctx, t, driver, hostDSN)
|
||||
require.Eventually(t, func() bool {
|
||||
setting, err := legacyStore.GetInstanceBasicSetting(ctx)
|
||||
return err == nil && setting != nil && setting.SchemaVersion != ""
|
||||
}, 45*time.Second, 500*time.Millisecond, "legacy schema should be initialized by old container")
|
||||
|
||||
settingBeforeSeed, err := legacyStore.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Legacy schema version before migration: %s", settingBeforeSeed.SchemaVersion)
|
||||
|
||||
err = container.Terminate(ctx)
|
||||
require.NoError(t, err, "failed to stop v0.26.2 memos container")
|
||||
container = nil
|
||||
|
||||
db := openMigrationSQLDB(t, driver, hostDSN)
|
||||
defer db.Close()
|
||||
|
||||
seedLegacyMigrationData(ctx, t, driver, db)
|
||||
|
||||
count, err := countSystemSetting(ctx, db, "STORAGE")
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, count, "v0.26.2 database should not have a STORAGE setting before migration")
|
||||
|
||||
ts := NewTestingStoreWithDSN(ctx, t, driver, hostDSN)
|
||||
err = ts.Migrate(ctx)
|
||||
require.NoError(t, err, "migration from v0.26.2 should succeed for %s", driver)
|
||||
|
||||
currentVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
currentSetting, err := ts.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, currentVersion, currentSetting.SchemaVersion, "schema version should be updated")
|
||||
|
||||
storageSetting, err := ts.GetInstanceStorageSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, storepb.InstanceStorageSetting_DATABASE, storageSetting.StorageType, "existing installs should stay on DATABASE storage")
|
||||
|
||||
idps, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, idps, 2)
|
||||
idpUIDsByName := map[string]string{}
|
||||
for _, idp := range idps {
|
||||
idpUIDsByName[idp.Name] = idp.Uid
|
||||
}
|
||||
require.Equal(t, "00000191", idpUIDsByName["Legacy Google"])
|
||||
require.Equal(t, "00000192", idpUIDsByName["Legacy GitHub"])
|
||||
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.NotNil(t, inboxes[0].Message)
|
||||
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inboxes[0].Message.Type)
|
||||
require.Equal(t, int32(102), inboxes[0].Message.GetMemoComment().MemoId)
|
||||
require.Equal(t, int32(101), inboxes[0].Message.GetMemoComment().RelatedMemoId)
|
||||
|
||||
activityExists, err := tableExists(ctx, db, driver, "activity")
|
||||
require.NoError(t, err)
|
||||
require.False(t, activityExists, "activity table should be removed after migration")
|
||||
|
||||
memoShareExists, err := tableExists(ctx, db, driver, "memo_share")
|
||||
require.NoError(t, err)
|
||||
require.True(t, memoShareExists, "memo_share table should be created")
|
||||
|
||||
share, err := ts.CreateMemoShare(ctx, &store.MemoShare{
|
||||
UID: "post-upgrade-share",
|
||||
MemoID: 101,
|
||||
CreatorID: 11,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "post-upgrade-share", share.UID)
|
||||
|
||||
postUpgradeUser, err := createTestingUserWithRole(ctx, ts, "postupgrade", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
postUpgradeMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "post-upgrade-memo-v0262",
|
||||
CreatorID: postUpgradeUser.ID,
|
||||
Content: "created after v0.26.2 migration",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "created after v0.26.2 migration", postUpgradeMemo.Content)
|
||||
}
|
||||
|
||||
func prepareV0262MigrationTest(t *testing.T, driver string) (MemosContainerConfig, string) {
|
||||
t.Helper()
|
||||
|
||||
const version = "0.26.2"
|
||||
|
||||
switch driver {
|
||||
case "sqlite":
|
||||
dataDir := t.TempDir()
|
||||
return MemosContainerConfig{
|
||||
Version: version,
|
||||
Driver: driver,
|
||||
DataDir: dataDir,
|
||||
}, fmt.Sprintf("%s/memos_prod.db", dataDir)
|
||||
case "mysql":
|
||||
hostDSN := GetMySQLDSN(t)
|
||||
containerDSN, err := getContainerDSN(driver, hostDSN)
|
||||
require.NoError(t, err)
|
||||
return MemosContainerConfig{
|
||||
Version: version,
|
||||
Driver: driver,
|
||||
DSN: containerDSN,
|
||||
}, hostDSN
|
||||
case "postgres":
|
||||
hostDSN := GetPostgresDSN(t)
|
||||
containerDSN, err := getContainerDSN(driver, hostDSN)
|
||||
require.NoError(t, err)
|
||||
return MemosContainerConfig{
|
||||
Version: version,
|
||||
Driver: driver,
|
||||
DSN: containerDSN,
|
||||
}, hostDSN
|
||||
default:
|
||||
t.Fatalf("unsupported driver: %s", driver)
|
||||
return MemosContainerConfig{}, ""
|
||||
}
|
||||
}
|
||||
|
||||
func openMigrationSQLDB(t *testing.T, driver, dsn string) *sql.DB {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open(driver, dsn)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Ping())
|
||||
return db
|
||||
}
|
||||
|
||||
func seedLegacyMigrationData(ctx context.Context, t *testing.T, driver string, db *sql.DB) {
|
||||
t.Helper()
|
||||
|
||||
execMigrationSQL(t, db, legacyInsertUserSQL(driver, 11, "owner"))
|
||||
execMigrationSQL(t, db, legacyInsertUserSQL(driver, 12, "commenter"))
|
||||
execMigrationSQL(t, db, legacyInsertMemoSQL(101, 11, "legacy-parent", "parent memo"))
|
||||
execMigrationSQL(t, db, legacyInsertMemoSQL(102, 12, "legacy-comment", "comment memo"))
|
||||
execMigrationSQL(t, db, legacyInsertActivitySQL(201, 12))
|
||||
execMigrationSQL(t, db, legacyInsertInboxSQL(301, 12, 11, 201))
|
||||
execMigrationSQL(t, db, legacyInsertIDPSQL(401, "Legacy Google"))
|
||||
execMigrationSQL(t, db, legacyInsertIDPSQL(402, "Legacy GitHub"))
|
||||
|
||||
var message string
|
||||
err := db.QueryRowContext(ctx, "SELECT message FROM inbox WHERE id = 301").Scan(&message)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, message, "\"activityId\":201")
|
||||
require.NotContains(t, message, "\"memoComment\"")
|
||||
}
|
||||
|
||||
func execMigrationSQL(t *testing.T, db *sql.DB, query string) {
|
||||
t.Helper()
|
||||
_, err := db.Exec(query)
|
||||
require.NoError(t, err, "failed to execute SQL: %s", query)
|
||||
}
|
||||
|
||||
func legacyInsertUserSQL(driver string, id int, username string) string {
|
||||
table := "user"
|
||||
switch driver {
|
||||
case "mysql":
|
||||
table = "`user`"
|
||||
case "postgres", "sqlite":
|
||||
table = `"user"`
|
||||
default:
|
||||
// Keep the unquoted fallback for unknown test drivers.
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO %s (id, username, role, email, nickname, password_hash, avatar_url, description) VALUES (%d, '%s', 'USER', '%s@example.com', '%s', 'legacy-hash', '', 'legacy user')",
|
||||
table, id, username, username, username,
|
||||
)
|
||||
}
|
||||
|
||||
func legacyInsertMemoSQL(id, creatorID int, uid, content string) string {
|
||||
payload := "{}"
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO memo (id, uid, creator_id, content, visibility, payload) VALUES (%d, '%s', %d, '%s', 'PRIVATE', '%s')",
|
||||
id, uid, creatorID, content, payload,
|
||||
)
|
||||
}
|
||||
|
||||
func legacyInsertActivitySQL(id, creatorID int) string {
|
||||
payload := `{"memoComment":{"memoId":102,"relatedMemoId":101}}`
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO activity (id, creator_id, type, level, payload) VALUES (%d, %d, 'MEMO_COMMENT', 'INFO', '%s')",
|
||||
id, creatorID, payload,
|
||||
)
|
||||
}
|
||||
|
||||
func legacyInsertInboxSQL(id, senderID, receiverID, activityID int) string {
|
||||
message := fmt.Sprintf(`{"type":"MEMO_COMMENT","activityId":%d}`, activityID)
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO inbox (id, sender_id, receiver_id, status, message) VALUES (%d, %d, %d, 'UNREAD', '%s')",
|
||||
id, senderID, receiverID, message,
|
||||
)
|
||||
}
|
||||
|
||||
func legacyInsertIDPSQL(id int, name string) string {
|
||||
config := `{"clientId":"legacy-client","clientSecret":"legacy-secret","authUrl":"https://example.com/auth","tokenUrl":"https://example.com/token","userInfoUrl":"https://example.com/userinfo"}`
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO idp (id, name, type, identifier_filter, config) VALUES (%d, '%s', 'OAUTH2', '', '%s')",
|
||||
id, name, config,
|
||||
)
|
||||
}
|
||||
|
||||
func countSystemSetting(ctx context.Context, db *sql.DB, name string) (int, error) {
|
||||
var count int
|
||||
err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_setting WHERE name = ?", name).Scan(&count)
|
||||
if err == nil {
|
||||
return count, nil
|
||||
}
|
||||
|
||||
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_setting WHERE name = $1", name).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func tableExists(ctx context.Context, db *sql.DB, driver, table string) (bool, error) {
|
||||
switch driver {
|
||||
case "sqlite":
|
||||
var name string
|
||||
err := db.QueryRowContext(ctx, "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", table).Scan(&name)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return err == nil, err
|
||||
case "mysql":
|
||||
var count int
|
||||
err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = ?", table).Scan(&count)
|
||||
return count > 0, err
|
||||
case "postgres":
|
||||
var regclass sql.NullString
|
||||
err := db.QueryRowContext(ctx, "SELECT to_regclass($1)", "public."+table).Scan(®class)
|
||||
return regclass.Valid && strings.EqualFold(regclass.String, table), err
|
||||
default:
|
||||
return false, errors.Errorf("unsupported driver: %s", driver)
|
||||
}
|
||||
}
|
||||
|
|
@ -91,14 +91,6 @@ func TestUserGetByID(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Nil(t, notFound)
|
||||
|
||||
// Get system bot
|
||||
systemBotID := store.SystemBotID
|
||||
systemBot, err := ts.GetUser(ctx, &store.FindUser{ID: &systemBotID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, systemBot)
|
||||
require.Equal(t, store.SystemBotID, systemBot.ID)
|
||||
require.Equal(t, "system_bot", systemBot.Username)
|
||||
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,20 +23,6 @@ func (e Role) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
const (
|
||||
SystemBotID int32 = 0
|
||||
)
|
||||
|
||||
var (
|
||||
SystemBot = &User{
|
||||
ID: SystemBotID,
|
||||
Username: "system_bot",
|
||||
Role: RoleAdmin,
|
||||
Email: "",
|
||||
Nickname: "Bot",
|
||||
}
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int32
|
||||
|
||||
|
|
@ -125,9 +111,6 @@ func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error)
|
|||
|
||||
func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) {
|
||||
if find.ID != nil {
|
||||
if *find.ID == SystemBotID {
|
||||
return SystemBot, nil
|
||||
}
|
||||
if cache, ok := s.userCache.Get(ctx, string(*find.ID)); ok {
|
||||
user, ok := cache.(*User)
|
||||
if ok {
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import { useEffect, useMemo, useState } from "react";
|
|||
import useDebounce from "react-use/lib/useDebounce";
|
||||
import { memoServiceClient } from "@/connect";
|
||||
import { DEFAULT_LIST_MEMOS_PAGE_SIZE } from "@/helpers/consts";
|
||||
import { extractUserIdFromName } from "@/helpers/resource-names";
|
||||
import { buildMemoCreatorFilter } from "@/helpers/resource-names";
|
||||
import useCurrentUser from "@/hooks/useCurrentUser";
|
||||
import {
|
||||
type Memo,
|
||||
|
|
@ -44,7 +44,11 @@ export const useLinkMemo = ({ isOpen, currentMemoName, existingRelations, onAddR
|
|||
|
||||
setIsFetching(true);
|
||||
try {
|
||||
const conditions = [`creator_id == ${extractUserIdFromName(user?.name ?? "")}`];
|
||||
const conditions: string[] = [];
|
||||
const creatorFilter = buildMemoCreatorFilter(user?.name ?? "");
|
||||
if (creatorFilter) {
|
||||
conditions.push(creatorFilter);
|
||||
}
|
||||
if (searchText) {
|
||||
conditions.push(`content.contains("${searchText}")`);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,25 +1,49 @@
|
|||
import { debounce } from "lodash-es";
|
||||
|
||||
export const CACHE_DEBOUNCE_DELAY = 500;
|
||||
|
||||
const pendingSaves = new Map<string, ReturnType<typeof window.setTimeout>>();
|
||||
|
||||
export const cacheService = {
|
||||
key: (username: string, cacheKey?: string): string => {
|
||||
return `${username}-${cacheKey || ""}`;
|
||||
},
|
||||
|
||||
save: debounce((key: string, content: string) => {
|
||||
if (content.trim()) {
|
||||
localStorage.setItem(key, content);
|
||||
} else {
|
||||
localStorage.removeItem(key);
|
||||
save: (key: string, content: string) => {
|
||||
const pendingSave = pendingSaves.get(key);
|
||||
if (pendingSave) {
|
||||
window.clearTimeout(pendingSave);
|
||||
}
|
||||
}, CACHE_DEBOUNCE_DELAY),
|
||||
|
||||
const timeoutId = window.setTimeout(() => {
|
||||
pendingSaves.delete(key);
|
||||
|
||||
if (content.trim()) {
|
||||
localStorage.setItem(key, content);
|
||||
} else {
|
||||
localStorage.removeItem(key);
|
||||
}
|
||||
}, CACHE_DEBOUNCE_DELAY);
|
||||
|
||||
pendingSaves.set(key, timeoutId);
|
||||
},
|
||||
|
||||
load(key: string): string {
|
||||
return localStorage.getItem(key) || "";
|
||||
},
|
||||
|
||||
clear(key: string): void {
|
||||
const pendingSave = pendingSaves.get(key);
|
||||
if (pendingSave) {
|
||||
window.clearTimeout(pendingSave);
|
||||
pendingSaves.delete(key);
|
||||
}
|
||||
|
||||
localStorage.removeItem(key);
|
||||
},
|
||||
|
||||
clearAll(): void {
|
||||
for (const timeoutId of pendingSaves.values()) {
|
||||
window.clearTimeout(timeoutId);
|
||||
}
|
||||
pendingSaves.clear();
|
||||
},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigge
|
|||
const emojiRegex = /^(\p{Emoji_Presentation}|\p{Emoji}\uFE0F)$/u;
|
||||
|
||||
// Helper function to extract shortcut ID from resource name
|
||||
// Format: users/{user}/shortcuts/{shortcut}
|
||||
// Format: users/{username}/shortcuts/{shortcut}
|
||||
const getShortcutId = (name: string): string => {
|
||||
const parts = name.split("/");
|
||||
return parts.length === 4 ? parts[3] : "";
|
||||
|
|
|
|||
|
|
@ -1,42 +1,20 @@
|
|||
import { FileAudioIcon, FileIcon, PaperclipIcon } from "lucide-react";
|
||||
import { FileIcon, PaperclipIcon } from "lucide-react";
|
||||
import { useMemo } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { Attachment } from "@/types/proto/api/v1/attachment_service_pb";
|
||||
import { getAttachmentType, getAttachmentUrl } from "@/utils/attachment";
|
||||
import { formatFileSize, getFileTypeLabel } from "@/utils/format";
|
||||
import { getAttachmentUrl } from "@/utils/attachment";
|
||||
import SectionHeader from "../SectionHeader";
|
||||
import AttachmentCard from "./AttachmentCard";
|
||||
import AudioAttachmentItem from "./AudioAttachmentItem";
|
||||
import { getAttachmentMetadata, isImageAttachment, separateAttachments } from "./attachmentViewHelpers";
|
||||
|
||||
interface AttachmentListViewProps {
|
||||
attachments: Attachment[];
|
||||
onImagePreview?: (urls: string[], index: number) => void;
|
||||
}
|
||||
|
||||
const isImageAttachment = (attachment: Attachment): boolean => getAttachmentType(attachment) === "image/*";
|
||||
const isVideoAttachment = (attachment: Attachment): boolean => getAttachmentType(attachment) === "video/*";
|
||||
const isAudioAttachment = (attachment: Attachment): boolean => getAttachmentType(attachment) === "audio/*";
|
||||
|
||||
const separateAttachments = (attachments: Attachment[]) => {
|
||||
const visual: Attachment[] = [];
|
||||
const audio: Attachment[] = [];
|
||||
const docs: Attachment[] = [];
|
||||
|
||||
for (const attachment of attachments) {
|
||||
if (isImageAttachment(attachment) || isVideoAttachment(attachment)) {
|
||||
visual.push(attachment);
|
||||
} else if (isAudioAttachment(attachment)) {
|
||||
audio.push(attachment);
|
||||
} else {
|
||||
docs.push(attachment);
|
||||
}
|
||||
}
|
||||
|
||||
return { visual, audio, docs };
|
||||
};
|
||||
|
||||
const DocumentItem = ({ attachment }: { attachment: Attachment }) => {
|
||||
const fileTypeLabel = getFileTypeLabel(attachment.type);
|
||||
const fileSizeLabel = attachment.size ? formatFileSize(Number(attachment.size)) : undefined;
|
||||
const { fileTypeLabel, fileSizeLabel } = getAttachmentMetadata(attachment);
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-1 px-1 py-1 rounded text-xs text-muted-foreground hover:text-foreground hover:bg-accent/20 transition-colors whitespace-nowrap">
|
||||
|
|
@ -62,22 +40,6 @@ const DocumentItem = ({ attachment }: { attachment: Attachment }) => {
|
|||
);
|
||||
};
|
||||
|
||||
const AudioItem = ({ attachment }: { attachment: Attachment }) => {
|
||||
const sourceUrl = getAttachmentUrl(attachment);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-1 px-1 py-1">
|
||||
<div className="flex items-center gap-1 text-xs text-muted-foreground">
|
||||
<FileAudioIcon className="w-3 h-3 shrink-0" />
|
||||
<span className="truncate" title={attachment.filename}>
|
||||
{attachment.filename}
|
||||
</span>
|
||||
</div>
|
||||
<audio src={sourceUrl} controls preload="metadata" className="w-full h-8" />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
interface VisualItemProps {
|
||||
attachment: Attachment;
|
||||
onImageClick?: (url: string) => void;
|
||||
|
|
@ -114,9 +76,9 @@ const VisualGrid = ({ attachments, onImageClick }: { attachments: Attachment[];
|
|||
);
|
||||
|
||||
const AudioList = ({ attachments }: { attachments: Attachment[] }) => (
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex flex-col gap-2">
|
||||
{attachments.map((attachment) => (
|
||||
<AudioItem key={attachment.name} attachment={attachment} />
|
||||
<AudioAttachmentItem key={attachment.name} attachment={attachment} />
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,178 @@
|
|||
import { FileAudioIcon, PauseIcon, PlayIcon } from "lucide-react";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import type { Attachment } from "@/types/proto/api/v1/attachment_service_pb";
|
||||
import { getAttachmentUrl } from "@/utils/attachment";
|
||||
import { formatAudioTime, getAttachmentMetadata } from "./attachmentViewHelpers";
|
||||
|
||||
const AUDIO_PLAYBACK_RATES = [1, 1.5, 2] as const;
|
||||
|
||||
interface AudioProgressBarProps {
|
||||
attachment: Attachment;
|
||||
currentTime: number;
|
||||
duration: number;
|
||||
progressPercent: number;
|
||||
onSeek: (value: string) => void;
|
||||
}
|
||||
|
||||
const AudioProgressBar = ({ attachment, currentTime, duration, progressPercent, onSeek }: AudioProgressBarProps) => (
|
||||
<div className="mt-2 flex items-center gap-2.5">
|
||||
<div className="relative flex h-4 min-w-0 flex-1 items-center">
|
||||
<div className="absolute inset-x-0 h-1 rounded-full bg-muted/75" />
|
||||
<div className="absolute left-0 h-1 rounded-full bg-foreground/20" style={{ width: `${Math.min(progressPercent, 100)}%` }} />
|
||||
<input
|
||||
type="range"
|
||||
min={0}
|
||||
max={duration || 1}
|
||||
step={0.1}
|
||||
value={Math.min(currentTime, duration || 0)}
|
||||
onChange={(e) => onSeek(e.target.value)}
|
||||
aria-label={`Seek ${attachment.filename}`}
|
||||
className="relative z-10 h-4 w-full cursor-pointer appearance-none bg-transparent outline-none disabled:cursor-default
|
||||
[&::-webkit-slider-runnable-track]:h-1 [&::-webkit-slider-runnable-track]:rounded-full
|
||||
[&::-webkit-slider-runnable-track]:bg-transparent
|
||||
[&::-webkit-slider-thumb]:mt-[-3px] [&::-webkit-slider-thumb]:size-2 [&::-webkit-slider-thumb]:appearance-none
|
||||
[&::-webkit-slider-thumb]:rounded-full [&::-webkit-slider-thumb]:border [&::-webkit-slider-thumb]:border-border/50
|
||||
[&::-webkit-slider-thumb]:bg-background/95
|
||||
[&::-moz-range-track]:h-1 [&::-moz-range-track]:rounded-full [&::-moz-range-track]:bg-transparent
|
||||
[&::-moz-range-thumb]:size-2 [&::-moz-range-thumb]:rounded-full [&::-moz-range-thumb]:border
|
||||
[&::-moz-range-thumb]:border-border/50 [&::-moz-range-thumb]:bg-background/95"
|
||||
disabled={duration === 0}
|
||||
/>
|
||||
</div>
|
||||
<div className="shrink-0 text-[11px] tabular-nums text-muted-foreground">
|
||||
{formatAudioTime(currentTime)} / {duration > 0 ? formatAudioTime(duration) : "--:--"}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
const AudioAttachmentItem = ({ attachment }: { attachment: Attachment }) => {
|
||||
const sourceUrl = getAttachmentUrl(attachment);
|
||||
const audioRef = useRef<HTMLAudioElement>(null);
|
||||
const [isPlaying, setIsPlaying] = useState(false);
|
||||
const [currentTime, setCurrentTime] = useState(0);
|
||||
const [duration, setDuration] = useState(0);
|
||||
const [playbackRate, setPlaybackRate] = useState<(typeof AUDIO_PLAYBACK_RATES)[number]>(1);
|
||||
const { fileTypeLabel, fileSizeLabel } = getAttachmentMetadata(attachment);
|
||||
const progressPercent = duration > 0 ? (currentTime / duration) * 100 : 0;
|
||||
|
||||
useEffect(() => {
|
||||
if (!audioRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
audioRef.current.playbackRate = playbackRate;
|
||||
}, [playbackRate]);
|
||||
|
||||
const togglePlayback = async () => {
|
||||
const audio = audioRef.current;
|
||||
|
||||
if (!audio) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (audio.paused) {
|
||||
try {
|
||||
await audio.play();
|
||||
} catch {
|
||||
setIsPlaying(false);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
};
|
||||
|
||||
const handleSeek = (value: string) => {
|
||||
const audio = audioRef.current;
|
||||
const nextTime = Number(value);
|
||||
|
||||
if (!audio || Number.isNaN(nextTime)) {
|
||||
return;
|
||||
}
|
||||
|
||||
audio.currentTime = nextTime;
|
||||
setCurrentTime(nextTime);
|
||||
};
|
||||
|
||||
const handlePlaybackRateChange = () => {
|
||||
const currentRateIndex = AUDIO_PLAYBACK_RATES.findIndex((rate) => rate === playbackRate);
|
||||
const nextRate = AUDIO_PLAYBACK_RATES[(currentRateIndex + 1) % AUDIO_PLAYBACK_RATES.length];
|
||||
setPlaybackRate(nextRate);
|
||||
};
|
||||
|
||||
const handleDuration = (value: number) => {
|
||||
setDuration(Number.isFinite(value) ? value : 0);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="rounded-xl border border-border/35 bg-background/70 px-2.5 py-2.5">
|
||||
<div className="flex items-start gap-2.5">
|
||||
<div className="mt-0.5 flex size-8 shrink-0 items-center justify-center rounded-lg bg-muted/55 text-muted-foreground">
|
||||
<FileAudioIcon className="size-3.5" />
|
||||
</div>
|
||||
|
||||
<div className="flex min-w-0 flex-1 items-start justify-between gap-3">
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="truncate text-sm font-medium leading-5 text-foreground" title={attachment.filename}>
|
||||
{attachment.filename}
|
||||
</div>
|
||||
<div className="flex flex-wrap items-center gap-x-1.5 gap-y-0.5 text-xs leading-4 text-muted-foreground">
|
||||
<span>{fileTypeLabel}</span>
|
||||
{fileSizeLabel && (
|
||||
<>
|
||||
<span className="text-muted-foreground/50">•</span>
|
||||
<span>{fileSizeLabel}</span>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mt-0.5 flex shrink-0 items-center gap-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={handlePlaybackRateChange}
|
||||
className="inline-flex h-6 items-center justify-center px-1 text-[11px] font-medium text-muted-foreground transition-colors hover:text-foreground"
|
||||
aria-label={`Playback speed ${playbackRate}x for ${attachment.filename}`}
|
||||
>
|
||||
{playbackRate}x
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={togglePlayback}
|
||||
className="inline-flex size-6.5 items-center justify-center rounded-md border border-border/45 bg-background/85 text-foreground transition-colors hover:bg-muted/45"
|
||||
aria-label={isPlaying ? `Pause ${attachment.filename}` : `Play ${attachment.filename}`}
|
||||
>
|
||||
{isPlaying ? <PauseIcon className="size-3" /> : <PlayIcon className="size-3 translate-x-[0.5px]" />}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<AudioProgressBar
|
||||
attachment={attachment}
|
||||
currentTime={currentTime}
|
||||
duration={duration}
|
||||
progressPercent={progressPercent}
|
||||
onSeek={handleSeek}
|
||||
/>
|
||||
|
||||
<audio
|
||||
ref={audioRef}
|
||||
src={sourceUrl}
|
||||
preload="metadata"
|
||||
className="hidden"
|
||||
onLoadedMetadata={(e) => handleDuration(e.currentTarget.duration)}
|
||||
onDurationChange={(e) => handleDuration(e.currentTarget.duration)}
|
||||
onTimeUpdate={(e) => setCurrentTime(e.currentTarget.currentTime)}
|
||||
onPlay={() => setIsPlaying(true)}
|
||||
onPause={() => setIsPlaying(false)}
|
||||
onEnded={() => {
|
||||
setIsPlaying(false);
|
||||
setCurrentTime(0);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default AudioAttachmentItem;
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
import type { Attachment } from "@/types/proto/api/v1/attachment_service_pb";
|
||||
import { getAttachmentType } from "@/utils/attachment";
|
||||
import { formatFileSize, getFileTypeLabel } from "@/utils/format";
|
||||
|
||||
export interface AttachmentGroups {
|
||||
visual: Attachment[];
|
||||
audio: Attachment[];
|
||||
docs: Attachment[];
|
||||
}
|
||||
|
||||
export interface AttachmentMetadata {
|
||||
fileTypeLabel: string;
|
||||
fileSizeLabel?: string;
|
||||
}
|
||||
|
||||
export const isImageAttachment = (attachment: Attachment): boolean => getAttachmentType(attachment) === "image/*";
|
||||
export const isVideoAttachment = (attachment: Attachment): boolean => getAttachmentType(attachment) === "video/*";
|
||||
export const isAudioAttachment = (attachment: Attachment): boolean => getAttachmentType(attachment) === "audio/*";
|
||||
|
||||
export const separateAttachments = (attachments: Attachment[]): AttachmentGroups => {
|
||||
const groups: AttachmentGroups = {
|
||||
visual: [],
|
||||
audio: [],
|
||||
docs: [],
|
||||
};
|
||||
|
||||
for (const attachment of attachments) {
|
||||
if (isImageAttachment(attachment) || isVideoAttachment(attachment)) {
|
||||
groups.visual.push(attachment);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isAudioAttachment(attachment)) {
|
||||
groups.audio.push(attachment);
|
||||
continue;
|
||||
}
|
||||
|
||||
groups.docs.push(attachment);
|
||||
}
|
||||
|
||||
return groups;
|
||||
};
|
||||
|
||||
export const getAttachmentMetadata = (attachment: Attachment): AttachmentMetadata => ({
|
||||
fileTypeLabel: getFileTypeLabel(attachment.type),
|
||||
fileSizeLabel: attachment.size ? formatFileSize(Number(attachment.size)) : undefined,
|
||||
});
|
||||
|
||||
export const formatAudioTime = (seconds: number): string => {
|
||||
if (!Number.isFinite(seconds) || seconds < 0) {
|
||||
return "0:00";
|
||||
}
|
||||
|
||||
const rounded = Math.floor(seconds);
|
||||
const hours = Math.floor(rounded / 3600);
|
||||
const minutes = Math.floor((rounded % 3600) / 60);
|
||||
const secs = rounded % 60;
|
||||
|
||||
if (hours > 0) {
|
||||
return `${hours}:${minutes.toString().padStart(2, "0")}:${secs.toString().padStart(2, "0")}`;
|
||||
}
|
||||
|
||||
return `${minutes}:${secs.toString().padStart(2, "0")}`;
|
||||
};
|
||||
|
|
@ -115,3 +115,123 @@ const MemoView: React.FC<MemoViewProps> = (props: MemoViewProps) => {
|
|||
};
|
||||
|
||||
export default memo(MemoView);
|
||||
import { memo, useCallback, useMemo, useRef, useState } from "react";
|
||||
import { useLocation } from "react-router-dom";
|
||||
import { useInstance } from "@/contexts/InstanceContext";
|
||||
import useCurrentUser from "@/hooks/useCurrentUser";
|
||||
import { useUser } from "@/hooks/useUserQueries";
|
||||
import { findTagMetadata } from "@/lib/tag";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { State } from "@/types/proto/api/v1/common_pb";
|
||||
import { isSuperUser } from "@/utils/user";
|
||||
import MemoEditor from "../MemoEditor";
|
||||
import PreviewImageDialog from "../PreviewImageDialog";
|
||||
import { MemoBody, MemoCommentListView, MemoHeader } from "./components";
|
||||
import { MEMO_CARD_BASE_CLASSES } from "./constants";
|
||||
import { useImagePreview } from "./hooks";
|
||||
import { computeCommentAmount, MemoViewContext } from "./MemoViewContext";
|
||||
import type { MemoViewProps } from "./types";
|
||||
|
||||
const MemoView: React.FC<MemoViewProps> = (props: MemoViewProps) => {
|
||||
const { memo: memoData, className, parentPage: parentPageProp, compact, showCreator, showVisibility, showPinned } = props;
|
||||
const cardRef = useRef<HTMLDivElement>(null);
|
||||
const [showEditor, setShowEditor] = useState(false);
|
||||
|
||||
const currentUser = useCurrentUser();
|
||||
const { tagsSetting } = useInstance();
|
||||
const creator = useUser(memoData.creator).data;
|
||||
const isArchived = memoData.state === State.ARCHIVED;
|
||||
const readonly = memoData.creator !== currentUser?.name && !isSuperUser(currentUser);
|
||||
const parentPage = parentPageProp || "/";
|
||||
|
||||
// Blur content when any tag has blur_content enabled in the instance tag settings.
|
||||
const [showBlurredContent, setShowBlurredContent] = useState(false);
|
||||
const blurred = memoData.tags?.some((tag) => findTagMetadata(tag, tagsSetting)?.blurContent) ?? false;
|
||||
const toggleBlurVisibility = useCallback(() => setShowBlurredContent((prev) => !prev), []);
|
||||
|
||||
const { previewState, openPreview, setPreviewOpen } = useImagePreview();
|
||||
|
||||
const openEditor = useCallback(() => setShowEditor(true), []);
|
||||
const closeEditor = useCallback(() => setShowEditor(false), []);
|
||||
|
||||
const location = useLocation();
|
||||
const isInMemoDetailPage = location.pathname.startsWith(`/${memoData.name}`) || location.pathname.startsWith("/memos/shares/");
|
||||
const showCommentPreview = !isInMemoDetailPage && computeCommentAmount(memoData) > 0;
|
||||
|
||||
const contextValue = useMemo(
|
||||
() => ({
|
||||
memo: memoData,
|
||||
creator,
|
||||
currentUser,
|
||||
parentPage,
|
||||
isArchived,
|
||||
readonly,
|
||||
showBlurredContent,
|
||||
blurred,
|
||||
openEditor,
|
||||
toggleBlurVisibility,
|
||||
openPreview,
|
||||
}),
|
||||
[
|
||||
memoData,
|
||||
creator,
|
||||
currentUser,
|
||||
parentPage,
|
||||
isArchived,
|
||||
readonly,
|
||||
showBlurredContent,
|
||||
blurred,
|
||||
openEditor,
|
||||
toggleBlurVisibility,
|
||||
openPreview,
|
||||
],
|
||||
);
|
||||
|
||||
if (showEditor) {
|
||||
return (
|
||||
<MemoEditor
|
||||
autoFocus
|
||||
className="mb-2"
|
||||
cacheKey={`inline-memo-editor-${memoData.name}`}
|
||||
memo={memoData}
|
||||
parentMemoName={memoData.parent || undefined}
|
||||
onConfirm={closeEditor}
|
||||
onCancel={closeEditor}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
const article = (
|
||||
<article
|
||||
className={cn(MEMO_CARD_BASE_CLASSES, showCommentPreview ? "mb-0 rounded-b-none" : "mb-2", className)}
|
||||
ref={cardRef}
|
||||
tabIndex={readonly ? -1 : 0}
|
||||
>
|
||||
<MemoHeader showCreator={showCreator} showVisibility={showVisibility} showPinned={showPinned} />
|
||||
|
||||
<MemoBody compact={compact} />
|
||||
|
||||
<PreviewImageDialog
|
||||
open={previewState.open}
|
||||
onOpenChange={setPreviewOpen}
|
||||
imgUrls={previewState.urls}
|
||||
initialIndex={previewState.index}
|
||||
/>
|
||||
</article>
|
||||
);
|
||||
|
||||
return (
|
||||
<MemoViewContext.Provider value={contextValue}>
|
||||
{showCommentPreview ? (
|
||||
<div className="w-full mb-2">
|
||||
{article}
|
||||
<MemoCommentListView />
|
||||
</div>
|
||||
) : (
|
||||
article
|
||||
)}
|
||||
</MemoViewContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(MemoView);
|
||||
|
|
|
|||
|
|
@ -22,8 +22,7 @@ import SettingGroup from "./SettingGroup";
|
|||
import SettingSection from "./SettingSection";
|
||||
import SettingTable from "./SettingTable";
|
||||
|
||||
// Fallback to white when no color is stored.
|
||||
const tagColorToHex = (color?: { red?: number; green?: number; blue?: number }): string => colorToHex(color) ?? "#ffffff";
|
||||
const DEFAULT_TAG_COLOR = "#ffffff";
|
||||
|
||||
// Converts a CSS hex string to a google.type.Color message.
|
||||
const hexToColor = (hex: string) =>
|
||||
|
|
@ -33,24 +32,36 @@ const hexToColor = (hex: string) =>
|
|||
blue: parseInt(hex.slice(5, 7), 16) / 255,
|
||||
});
|
||||
|
||||
interface LocalTagMeta {
|
||||
color?: string;
|
||||
blur: boolean;
|
||||
}
|
||||
|
||||
const toLocalTagMeta = (meta: {
|
||||
backgroundColor?: { red?: number; green?: number; blue?: number };
|
||||
blurContent: boolean;
|
||||
}): LocalTagMeta => ({
|
||||
color: colorToHex(meta.backgroundColor),
|
||||
blur: meta.blurContent,
|
||||
});
|
||||
|
||||
const TagsSection = () => {
|
||||
const t = useTranslate();
|
||||
const { tagsSetting: originalSetting, updateSetting, fetchSetting } = useInstance();
|
||||
const { data: tagCounts = {} } = useTagCounts(false);
|
||||
|
||||
// Local state: map of tagName → hex color string for editing.
|
||||
const [localTags, setLocalTags] = useState<Record<string, string>>(() =>
|
||||
Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, tagColorToHex(meta.backgroundColor)])),
|
||||
// Local state: map of tagName → { color, blur } for editing.
|
||||
const [localTags, setLocalTags] = useState<Record<string, LocalTagMeta>>(() =>
|
||||
Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, toLocalTagMeta(meta)])),
|
||||
);
|
||||
const [newTagName, setNewTagName] = useState("");
|
||||
const [newTagColor, setNewTagColor] = useState("#ffffff");
|
||||
const [newTagColor, setNewTagColor] = useState<string | undefined>(undefined);
|
||||
const [newTagBlur, setNewTagBlur] = useState(false);
|
||||
|
||||
// Sync local state when the fetched setting arrives (the fetch is async and
|
||||
// completes after mount, so localTags would be empty without this sync).
|
||||
useEffect(() => {
|
||||
setLocalTags(
|
||||
Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, tagColorToHex(meta.backgroundColor)])),
|
||||
);
|
||||
setLocalTags(Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, toLocalTagMeta(meta)])));
|
||||
}, [originalSetting.tags]);
|
||||
|
||||
// All known tag names: union of saved entries and tags used in memos.
|
||||
|
|
@ -68,8 +79,8 @@ const TagsSection = () => {
|
|||
[localTags],
|
||||
);
|
||||
|
||||
const originalHexMap = useMemo(
|
||||
() => Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, tagColorToHex(meta.backgroundColor)])),
|
||||
const originalMetaMap = useMemo(
|
||||
() => Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, toLocalTagMeta(meta)])),
|
||||
[originalSetting.tags],
|
||||
);
|
||||
const hasChanges = !isEqual(localTags, originalHexMap);
|
||||
|
|
@ -78,6 +89,10 @@ const TagsSection = () => {
|
|||
setLocalTags((prev) => ({ ...prev, [tagName]: hex }));
|
||||
};
|
||||
|
||||
const handleClearColor = (tagName: string) => {
|
||||
setLocalTags((prev) => ({ ...prev, [tagName]: { ...prev[tagName], color: undefined } }));
|
||||
};
|
||||
|
||||
const handleRemoveTag = (tagName: string) => {
|
||||
setLocalTags((prev) => {
|
||||
const next = { ...prev };
|
||||
|
|
@ -99,7 +114,8 @@ const TagsSection = () => {
|
|||
}
|
||||
setLocalTags((prev) => ({ ...prev, [name]: newTagColor }));
|
||||
setNewTagName("");
|
||||
setNewTagColor("#ffffff");
|
||||
setNewTagColor(undefined);
|
||||
setNewTagBlur(false);
|
||||
};
|
||||
|
||||
const handleSave = async () => {
|
||||
|
|
@ -107,7 +123,10 @@ const TagsSection = () => {
|
|||
const tags = Object.fromEntries(
|
||||
Object.entries(localTags).map(([name, hex]) => [
|
||||
name,
|
||||
create(InstanceSetting_TagMetadataSchema, { backgroundColor: hexToColor(hex) }),
|
||||
create(InstanceSetting_TagMetadataSchema, {
|
||||
blurContent: meta.blur,
|
||||
...(meta.color ? { backgroundColor: hexToColor(meta.color) } : {}),
|
||||
}),
|
||||
]),
|
||||
);
|
||||
await updateSetting(
|
||||
|
|
@ -144,9 +163,15 @@ const TagsSection = () => {
|
|||
<input
|
||||
type="color"
|
||||
className="w-8 h-8 cursor-pointer rounded border border-border bg-transparent p-0.5"
|
||||
value={localTags[row.name]}
|
||||
value={localTags[row.name].color ?? DEFAULT_TAG_COLOR}
|
||||
onChange={(e) => handleColorChange(row.name, e.target.value)}
|
||||
/>
|
||||
<Button variant="ghost" size="sm" onClick={() => handleClearColor(row.name)} disabled={!localTags[row.name].color}>
|
||||
{t("common.clear")}
|
||||
</Button>
|
||||
{!localTags[row.name].color && (
|
||||
<span className="text-xs text-muted-foreground">{t("setting.tags.using-default-color")}</span>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
},
|
||||
|
|
@ -185,15 +210,28 @@ const TagsSection = () => {
|
|||
<input
|
||||
type="color"
|
||||
className="w-8 h-8 cursor-pointer rounded border border-border bg-transparent p-0.5"
|
||||
value={newTagColor}
|
||||
value={newTagColor ?? DEFAULT_TAG_COLOR}
|
||||
onChange={(e) => setNewTagColor(e.target.value)}
|
||||
/>
|
||||
<Button variant="ghost" size="sm" onClick={() => setNewTagColor(undefined)} disabled={!newTagColor}>
|
||||
{t("common.clear")}
|
||||
</Button>
|
||||
<label className="flex items-center gap-1.5 text-sm text-muted-foreground">
|
||||
<input
|
||||
type="checkbox"
|
||||
className="w-4 h-4 cursor-pointer"
|
||||
checked={newTagBlur}
|
||||
onChange={(e) => setNewTagBlur(e.target.checked)}
|
||||
/>
|
||||
{t("setting.tags.blur-content")}
|
||||
</label>
|
||||
<Button variant="outline" onClick={handleAddTag} disabled={!newTagName.trim()}>
|
||||
<PlusIcon className="w-4 h-4 mr-1.5" />
|
||||
{t("common.add")}
|
||||
</Button>
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground mt-1">{t("setting.tags.tag-pattern-hint")}</p>
|
||||
{!newTagColor && <p className="text-xs text-muted-foreground">{t("setting.tags.using-default-color")}</p>}
|
||||
</SettingGroup>
|
||||
|
||||
<div className="w-full flex justify-end">
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import { MapContainer, Marker, Popup, useMap } from "react-leaflet";
|
|||
import MarkerClusterGroup from "react-leaflet-cluster";
|
||||
import { Link } from "react-router-dom";
|
||||
import { defaultMarkerIcon, ThemedTileLayer } from "@/components/map/map-utils";
|
||||
import { buildMemoCreatorFilter } from "@/helpers/resource-names";
|
||||
import { useInfiniteMemos } from "@/hooks/useMemoQueries";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { State } from "@/types/proto/api/v1/common_pb";
|
||||
|
|
@ -30,11 +31,6 @@ const createClusterCustomIcon = (cluster: ClusterGroup) => {
|
|||
});
|
||||
};
|
||||
|
||||
const extractUserIdFromName = (name: string): string => {
|
||||
const match = name.match(/users\/(\d+)/);
|
||||
return match ? match[1] : "";
|
||||
};
|
||||
|
||||
const MapFitBounds = ({ memos }: { memos: Memo[] }) => {
|
||||
const map = useMap();
|
||||
|
||||
|
|
@ -52,14 +48,17 @@ const MapFitBounds = ({ memos }: { memos: Memo[] }) => {
|
|||
};
|
||||
|
||||
const UserMemoMap = ({ creator, className }: Props) => {
|
||||
const creatorId = useMemo(() => extractUserIdFromName(creator), [creator]);
|
||||
const creatorFilter = useMemo(() => buildMemoCreatorFilter(creator), [creator]);
|
||||
|
||||
const { data, isLoading } = useInfiniteMemos({
|
||||
state: State.NORMAL,
|
||||
orderBy: "display_time desc",
|
||||
pageSize: 1000,
|
||||
filter: `creator_id == ${creatorId}`,
|
||||
});
|
||||
const { data, isLoading } = useInfiniteMemos(
|
||||
{
|
||||
state: State.NORMAL,
|
||||
orderBy: "display_time desc",
|
||||
pageSize: 1000,
|
||||
filter: creatorFilter,
|
||||
},
|
||||
{ enabled: Boolean(creatorFilter) },
|
||||
);
|
||||
|
||||
const memosWithLocation = useMemo(() => data?.pages.flatMap((page) => page.memos).filter((memo) => memo.location) || [], [data]);
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,12 @@ export const userNamePrefix = "users/";
|
|||
export const memoNamePrefix = "memos/";
|
||||
export const identityProviderNamePrefix = "identity-providers/";
|
||||
|
||||
export const extractUserIdFromName = (name: string) => {
|
||||
return name.split(userNamePrefix).pop() || "";
|
||||
export const buildMemoCreatorFilter = (name: string) => {
|
||||
if (!name) {
|
||||
return undefined;
|
||||
}
|
||||
const normalizedName = name.startsWith(userNamePrefix) ? name : `${userNamePrefix}${name}`;
|
||||
return `creator == ${JSON.stringify(normalizedName)}`;
|
||||
};
|
||||
|
||||
export const extractMemoIdFromName = (name: string) => {
|
||||
|
|
|
|||
|
|
@ -12,6 +12,15 @@ const INITIAL_RETRY_DELAY_MS = 1000;
|
|||
const MAX_RETRY_DELAY_MS = 30000;
|
||||
const RETRY_BACKOFF_MULTIPLIER = 2;
|
||||
|
||||
const SSE_EVENT_TYPES = {
|
||||
memoCreated: "memo.created",
|
||||
memoUpdated: "memo.updated",
|
||||
memoDeleted: "memo.deleted",
|
||||
memoCommentCreated: "memo.comment.created",
|
||||
reactionUpserted: "reaction.upserted",
|
||||
reactionDeleted: "reaction.deleted",
|
||||
} as const;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared connection status store (singleton)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
@ -63,6 +72,7 @@ export function useLiveMemoRefresh() {
|
|||
const { currentUser } = useAuth();
|
||||
const retryDelayRef = useRef(INITIAL_RETRY_DELAY_MS);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const hasConnectedOnceRef = useRef(false);
|
||||
|
||||
const currentUserName = currentUser?.name;
|
||||
const handleEvent = useCallback((event: SSEChangeEvent) => handleSSEEvent(event, queryClient), [queryClient]);
|
||||
|
|
@ -101,6 +111,13 @@ export function useLiveMemoRefresh() {
|
|||
// Successfully connected - reset retry delay.
|
||||
retryDelayRef.current = INITIAL_RETRY_DELAY_MS;
|
||||
setSSEStatus("connected");
|
||||
if (hasConnectedOnceRef.current) {
|
||||
// Resync active collaborative views after reconnect because the server may have
|
||||
// dropped events while the client was disconnected or backpressured.
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.all, refetchType: "active" });
|
||||
queryClient.invalidateQueries({ queryKey: userKeys.stats(), refetchType: "active" });
|
||||
}
|
||||
hasConnectedOnceRef.current = true;
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
|
@ -175,37 +192,44 @@ export function useLiveMemoRefresh() {
|
|||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface SSEChangeEvent {
|
||||
type: string;
|
||||
type: (typeof SSE_EVENT_TYPES)[keyof typeof SSE_EVENT_TYPES];
|
||||
name: string;
|
||||
parent?: string;
|
||||
}
|
||||
|
||||
function handleSSEEvent(event: SSEChangeEvent, queryClient: ReturnType<typeof useQueryClient>) {
|
||||
switch (event.type) {
|
||||
case "memo.created":
|
||||
case SSE_EVENT_TYPES.memoCreated:
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
|
||||
queryClient.invalidateQueries({ queryKey: userKeys.stats() });
|
||||
break;
|
||||
|
||||
case "memo.updated":
|
||||
case SSE_EVENT_TYPES.memoUpdated:
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) });
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
|
||||
if (event.parent) {
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.comments(event.parent) });
|
||||
}
|
||||
break;
|
||||
|
||||
case "memo.deleted":
|
||||
case SSE_EVENT_TYPES.memoDeleted:
|
||||
queryClient.removeQueries({ queryKey: memoKeys.detail(event.name) });
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
|
||||
queryClient.invalidateQueries({ queryKey: userKeys.stats() });
|
||||
break;
|
||||
|
||||
case "memo.comment.created":
|
||||
case SSE_EVENT_TYPES.memoCommentCreated:
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.comments(event.name) });
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) });
|
||||
break;
|
||||
|
||||
case "reaction.upserted":
|
||||
case "reaction.deleted":
|
||||
case SSE_EVENT_TYPES.reactionUpserted:
|
||||
case SSE_EVENT_TYPES.reactionDeleted:
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.detail(event.name) });
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.lists() });
|
||||
if (event.parent) {
|
||||
queryClient.invalidateQueries({ queryKey: memoKeys.comments(event.parent) });
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,13 +2,9 @@ import { useMemo } from "react";
|
|||
import { useAuth } from "@/contexts/AuthContext";
|
||||
import { useInstance } from "@/contexts/InstanceContext";
|
||||
import { useMemoFilterContext } from "@/contexts/MemoFilterContext";
|
||||
import { buildMemoCreatorFilter } from "@/helpers/resource-names";
|
||||
import { Visibility } from "@/types/proto/api/v1/memo_service_pb";
|
||||
|
||||
const extractUserIdFromName = (name: string): string => {
|
||||
const match = name.match(/users\/(\d+)/);
|
||||
return match ? match[1] : "";
|
||||
};
|
||||
|
||||
const getVisibilityName = (visibility: Visibility): string => {
|
||||
switch (visibility) {
|
||||
case Visibility.PUBLIC:
|
||||
|
|
@ -27,6 +23,8 @@ const getShortcutId = (name: string): string => {
|
|||
return parts.length === 4 ? parts[3] : "";
|
||||
};
|
||||
|
||||
const escapeFilterValue = (value: string): string => JSON.stringify(value);
|
||||
|
||||
export interface UseMemoFiltersOptions {
|
||||
creatorName?: string;
|
||||
includeShortcuts?: boolean;
|
||||
|
|
@ -53,7 +51,10 @@ export const useMemoFilters = (options: UseMemoFiltersOptions = {}): string | un
|
|||
|
||||
// Add creator filter if provided
|
||||
if (creatorName) {
|
||||
conditions.push(`creator_id == ${extractUserIdFromName(creatorName)}`);
|
||||
const creatorFilter = buildMemoCreatorFilter(creatorName);
|
||||
if (creatorFilter) {
|
||||
conditions.push(creatorFilter);
|
||||
}
|
||||
}
|
||||
|
||||
// Add shortcut filter if enabled and selected
|
||||
|
|
@ -64,9 +65,9 @@ export const useMemoFilters = (options: UseMemoFiltersOptions = {}): string | un
|
|||
// Add active filters from context
|
||||
for (const filter of filters) {
|
||||
if (filter.factor === "contentSearch") {
|
||||
conditions.push(`content.contains("${filter.value}")`);
|
||||
conditions.push(`content.contains(${escapeFilterValue(filter.value)})`);
|
||||
} else if (filter.factor === "tagSearch") {
|
||||
conditions.push(`tag in ["${filter.value}"]`);
|
||||
conditions.push(`tag in [${escapeFilterValue(filter.value)}]`);
|
||||
} else if (filter.factor === "pinned") {
|
||||
if (includePinned) {
|
||||
conditions.push(`pinned`);
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ const MainLayout = () => {
|
|||
if (match && context === "profile") {
|
||||
const username = match.params.username;
|
||||
if (username) {
|
||||
// Fetch or get user to obtain user name (e.g., "users/123")
|
||||
// Fetch or get user to obtain the canonical user name (e.g., "users/steven")
|
||||
// Note: User stats will be fetched by useFilteredMemoStats
|
||||
userServiceClient
|
||||
.getUser({ name: `users/${username}` })
|
||||
|
|
|
|||
|
|
@ -474,14 +474,15 @@
|
|||
"tags": {
|
||||
"label": "Tags",
|
||||
"title": "Tag metadata",
|
||||
"description": "Assign display colors to tags instance-wide. Tag names are treated as anchored regex patterns.",
|
||||
"description": "Assign optional display colors to tags instance-wide, or blur matching memo content. Tag names are treated as anchored regex patterns.",
|
||||
"background-color": "Background color",
|
||||
"no-tags-configured": "No tag metadata configured.",
|
||||
"tag-name": "Tag name",
|
||||
"tag-name-placeholder": "e.g. work or project/.*",
|
||||
"tag-already-exists": "Tag already exists.",
|
||||
"tag-pattern-hint": "Tag name or regex pattern (e.g. project/.* matches all project/ tags)",
|
||||
"invalid-regex": "Invalid or unsafe regex pattern."
|
||||
"invalid-regex": "Invalid or unsafe regex pattern.",
|
||||
"using-default-color": "Using default color."
|
||||
}
|
||||
},
|
||||
"tag": {
|
||||
|
|
|
|||
|
|
@ -414,7 +414,8 @@ export const InstanceSetting_MemoRelatedSettingSchema: GenMessage<InstanceSettin
|
|||
*/
|
||||
export type InstanceSetting_TagMetadata = Message<"memos.api.v1.InstanceSetting.TagMetadata"> & {
|
||||
/**
|
||||
* Background color for the tag label.
|
||||
* Optional background color for the tag label.
|
||||
* When unset, the default tag color is used.
|
||||
*
|
||||
* @generated from field: google.type.Color background_color = 1;
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import type { Message } from "@bufbuild/protobuf";
|
|||
* Describes the file api/v1/shortcut_service.proto.
|
||||
*/
|
||||
export const file_api_v1_shortcut_service: GenFile = /*@__PURE__*/
|
||||
fileDesc("Ch1hcGkvdjEvc2hvcnRjdXRfc2VydmljZS5wcm90bxIMbWVtb3MuYXBpLnYxIpoBCghTaG9ydGN1dBIRCgRuYW1lGAEgASgJQgPgQQgSEgoFdGl0bGUYAiABKAlCA+BBAhITCgZmaWx0ZXIYAyABKAlCA+BBATpS6kFPChVtZW1vcy5hcGkudjEvU2hvcnRjdXQSIXVzZXJzL3t1c2VyfS9zaG9ydGN1dHMve3Nob3J0Y3V0fSoJc2hvcnRjdXRzMghzaG9ydGN1dCJFChRMaXN0U2hvcnRjdXRzUmVxdWVzdBItCgZwYXJlbnQYASABKAlCHeBBAvpBFxIVbWVtb3MuYXBpLnYxL1Nob3J0Y3V0IkIKFUxpc3RTaG9ydGN1dHNSZXNwb25zZRIpCglzaG9ydGN1dHMYASADKAsyFi5tZW1vcy5hcGkudjEuU2hvcnRjdXQiQQoSR2V0U2hvcnRjdXRSZXF1ZXN0EisKBG5hbWUYASABKAlCHeBBAvpBFwoVbWVtb3MuYXBpLnYxL1Nob3J0Y3V0IpEBChVDcmVhdGVTaG9ydGN1dFJlcXVlc3QSLQoGcGFyZW50GAEgASgJQh3gQQL6QRcSFW1lbW9zLmFwaS52MS9TaG9ydGN1dBItCghzaG9ydGN1dBgCIAEoCzIWLm1lbW9zLmFwaS52MS5TaG9ydGN1dEID4EECEhoKDXZhbGlkYXRlX29ubHkYAyABKAhCA+BBASJ8ChVVcGRhdGVTaG9ydGN1dFJlcXVlc3QSLQoIc2hvcnRjdXQYASABKAsyFi5tZW1vcy5hcGkudjEuU2hvcnRjdXRCA+BBAhI0Cgt1cGRhdGVfbWFzaxgCIAEoCzIaLmdvb2dsZS5wcm90b2J1Zi5GaWVsZE1hc2tCA+BBASJEChVEZWxldGVTaG9ydGN1dFJlcXVlc3QSKwoEbmFtZRgBIAEoCUId4EEC+kEXChVtZW1vcy5hcGkudjEvU2hvcnRjdXQy3gUKD1Nob3J0Y3V0U2VydmljZRKNAQoNTGlzdFNob3J0Y3V0cxIiLm1lbW9zLmFwaS52MS5MaXN0U2hvcnRjdXRzUmVxdWVzdBojLm1lbW9zLmFwaS52MS5MaXN0U2hvcnRjdXRzUmVzcG9uc2UiM9pBBnBhcmVudILT5JMCJBIiL2FwaS92MS97cGFyZW50PXVzZXJzLyp9L3Nob3J0Y3V0cxJ6CgtHZXRTaG9ydGN1dBIgLm1lbW9zLmFwaS52MS5HZXRTaG9ydGN1dFJlcXVlc3QaFi5tZW1vcy5hcGkudjEuU2hvcnRjdXQiMdpBBG5hbWWC0+STAiQSIi9hcGkvdjEve25hbWU9dXNlcnMvKi9zaG9ydGN1dHMvKn0SlQEKDkNyZWF0ZVNob3J0Y3V0EiMubWVtb3MuYXBpLnYxLkNyZWF0ZVNob3J0Y3V0UmVxdWVzdBoWLm1lbW9zLmFwaS52MS5TaG9ydGN1dCJG2kEPcGFyZW50LHNob3J0Y3V0gtPkkwIuOghzaG9ydGN1dCIiL2FwaS92MS97cGFyZW50PXVzZXJzLyp9L3Nob3J0Y3V0cxKjAQoOVXBkYXRlU2hvcnRjdXQSIy5tZW1vcy5hcGkudjEuVXBkYXRlU2hvcnRjdXRSZXF1ZXN0GhYubWVtb3MuYXBpLnYxLlNob3J0Y3V0IlTaQRRzaG9ydGN1dCx1cGRhdGVfbWFza4LT5JMCNzoIc2hvcnRjdXQyKy9hcGkvdjEve3Nob3J0Y3V0Lm5hbWU9dXNlcnMvKi9zaG9ydGN1dHMvKn0SgAEKDkRlbGV0ZVNob3J0Y3V0EiMubWVtb3MuYXBpLnYxLkRlbGV0ZVNob3J0Y3V0UmVxdWVzdBoWLmdvb2dsZS5wcm90b2J1Zi5FbXB0eSIx2kEEbmFtZYLT5JMCJCoiL2FwaS92MS97bmFtZT11c2Vycy8qL3Nob3J0Y3V0cy8qfUKsAQoQY29tLm1lbW9zLmFwaS52MUIUU2hvcnRjdXRTZXJ2aWNlUHJvdG9QAVowZ2l0aHViLmNvbS91c2VtZW1vcy9tZW1vcy9wcm90by9nZW4vYXBpL3YxO2FwaXYxogIDTUFYqgIMTWVtb3MuQXBpLlYxygIMTWVtb3NcQXBpXFYx4gIYTWVtb3NcQXBpXFYxXEdQQk1ldGFkYXRh6gIOTWVtb3M6OkFwaTo6VjFiBnByb3RvMw", [file_google_api_annotations, file_google_api_client, file_google_api_field_behavior, file_google_api_resource, file_google_protobuf_empty, file_google_protobuf_field_mask]);
|
||||
fileDesc("Ch1hcGkvdjEvc2hvcnRjdXRfc2VydmljZS5wcm90bxIMbWVtb3MuYXBpLnYxIp4BCghTaG9ydGN1dBIRCgRuYW1lGAEgASgJQgPgQQgSEgoFdGl0bGUYAiABKAlCA+BBAhITCgZmaWx0ZXIYAyABKAlCA+BBATpW6kFTChVtZW1vcy5hcGkudjEvU2hvcnRjdXQSJXVzZXJzL3t1c2VybmFtZX0vc2hvcnRjdXRzL3tzaG9ydGN1dH0qCXNob3J0Y3V0czIIc2hvcnRjdXQiRQoUTGlzdFNob3J0Y3V0c1JlcXVlc3QSLQoGcGFyZW50GAEgASgJQh3gQQL6QRcSFW1lbW9zLmFwaS52MS9TaG9ydGN1dCJCChVMaXN0U2hvcnRjdXRzUmVzcG9uc2USKQoJc2hvcnRjdXRzGAEgAygLMhYubWVtb3MuYXBpLnYxLlNob3J0Y3V0IkEKEkdldFNob3J0Y3V0UmVxdWVzdBIrCgRuYW1lGAEgASgJQh3gQQL6QRcKFW1lbW9zLmFwaS52MS9TaG9ydGN1dCKRAQoVQ3JlYXRlU2hvcnRjdXRSZXF1ZXN0Ei0KBnBhcmVudBgBIAEoCUId4EEC+kEXEhVtZW1vcy5hcGkudjEvU2hvcnRjdXQSLQoIc2hvcnRjdXQYAiABKAsyFi5tZW1vcy5hcGkudjEuU2hvcnRjdXRCA+BBAhIaCg12YWxpZGF0ZV9vbmx5GAMgASgIQgPgQQEifAoVVXBkYXRlU2hvcnRjdXRSZXF1ZXN0Ei0KCHNob3J0Y3V0GAEgASgLMhYubWVtb3MuYXBpLnYxLlNob3J0Y3V0QgPgQQISNAoLdXBkYXRlX21hc2sYAiABKAsyGi5nb29nbGUucHJvdG9idWYuRmllbGRNYXNrQgPgQQEiRAoVRGVsZXRlU2hvcnRjdXRSZXF1ZXN0EisKBG5hbWUYASABKAlCHeBBAvpBFwoVbWVtb3MuYXBpLnYxL1Nob3J0Y3V0Mt4FCg9TaG9ydGN1dFNlcnZpY2USjQEKDUxpc3RTaG9ydGN1dHMSIi5tZW1vcy5hcGkudjEuTGlzdFNob3J0Y3V0c1JlcXVlc3QaIy5tZW1vcy5hcGkudjEuTGlzdFNob3J0Y3V0c1Jlc3BvbnNlIjPaQQZwYXJlbnSC0+STAiQSIi9hcGkvdjEve3BhcmVudD11c2Vycy8qfS9zaG9ydGN1dHMSegoLR2V0U2hvcnRjdXQSIC5tZW1vcy5hcGkudjEuR2V0U2hvcnRjdXRSZXF1ZXN0GhYubWVtb3MuYXBpLnYxLlNob3J0Y3V0IjHaQQRuYW1lgtPkkwIkEiIvYXBpL3YxL3tuYW1lPXVzZXJzLyovc2hvcnRjdXRzLyp9EpUBCg5DcmVhdGVTaG9ydGN1dBIjLm1lbW9zLmFwaS52MS5DcmVhdGVTaG9ydGN1dFJlcXVlc3QaFi5tZW1vcy5hcGkudjEuU2hvcnRjdXQiRtpBD3BhcmVudCxzaG9ydGN1dILT5JMCLjoIc2hvcnRjdXQiIi9hcGkvdjEve3BhcmVudD11c2Vycy8qfS9zaG9ydGN1dHMSowEKDlVwZGF0ZVNob3J0Y3V0EiMubWVtb3MuYXBpLnYxLlVwZGF0ZVNob3J0Y3V0UmVxdWVzdBoWLm1lbW9zLmFwaS52MS5TaG9ydGN1dCJU2kEUc2hvcnRjdXQsdXBkYXRlX21hc2uC0+STAjc6CHNob3J0Y3V0MisvYXBpL3YxL3tzaG9ydGN1dC5uYW1lPXVzZXJzLyovc2hvcnRjdXRzLyp9EoABCg5EZWxldGVTaG9ydGN1dBIjLm1lbW9zLmFwaS52MS5EZWxldGVTaG9ydGN1dFJlcXVlc3QaFi5nb29nbGUucHJvdG9idWYuRW1wdHkiMdpBBG5hbWWC0+STAiQqIi9hcGkvdjEve25hbWU9dXNlcnMvKi9zaG9ydGN1dHMvKn1CrAEKEGNvbS5tZW1vcy5hcGkudjFCFFNob3J0Y3V0U2VydmljZVByb3RvUAFaMGdpdGh1Yi5jb20vdXNlbWVtb3MvbWVtb3MvcHJvdG8vZ2VuL2FwaS92MTthcGl2MaICA01BWKoCDE1lbW9zLkFwaS5WMcoCDE1lbW9zXEFwaVxWMeICGE1lbW9zXEFwaVxWMVxHUEJNZXRhZGF0YeoCDk1lbW9zOjpBcGk6OlYxYgZwcm90bzM", [file_google_api_annotations, file_google_api_client, file_google_api_field_behavior, file_google_api_resource, file_google_protobuf_empty, file_google_protobuf_field_mask]);
|
||||
|
||||
/**
|
||||
* @generated from message memos.api.v1.Shortcut
|
||||
|
|
@ -24,7 +24,7 @@ export const file_api_v1_shortcut_service: GenFile = /*@__PURE__*/
|
|||
export type Shortcut = Message<"memos.api.v1.Shortcut"> & {
|
||||
/**
|
||||
* The resource name of the shortcut.
|
||||
* Format: users/{user}/shortcuts/{shortcut}
|
||||
* Format: users/{username}/shortcuts/{shortcut}
|
||||
*
|
||||
* @generated from field: string name = 1;
|
||||
*/
|
||||
|
|
@ -58,7 +58,7 @@ export const ShortcutSchema: GenMessage<Shortcut> = /*@__PURE__*/
|
|||
export type ListShortcutsRequest = Message<"memos.api.v1.ListShortcutsRequest"> & {
|
||||
/**
|
||||
* Required. The parent resource where shortcuts are listed.
|
||||
* Format: users/{user}
|
||||
* Format: users/{username}
|
||||
*
|
||||
* @generated from field: string parent = 1;
|
||||
*/
|
||||
|
|
@ -97,7 +97,7 @@ export const ListShortcutsResponseSchema: GenMessage<ListShortcutsResponse> = /*
|
|||
export type GetShortcutRequest = Message<"memos.api.v1.GetShortcutRequest"> & {
|
||||
/**
|
||||
* Required. The resource name of the shortcut to retrieve.
|
||||
* Format: users/{user}/shortcuts/{shortcut}
|
||||
* Format: users/{username}/shortcuts/{shortcut}
|
||||
*
|
||||
* @generated from field: string name = 1;
|
||||
*/
|
||||
|
|
@ -117,7 +117,7 @@ export const GetShortcutRequestSchema: GenMessage<GetShortcutRequest> = /*@__PUR
|
|||
export type CreateShortcutRequest = Message<"memos.api.v1.CreateShortcutRequest"> & {
|
||||
/**
|
||||
* Required. The parent resource where this shortcut will be created.
|
||||
* Format: users/{user}
|
||||
* Format: users/{username}
|
||||
*
|
||||
* @generated from field: string parent = 1;
|
||||
*/
|
||||
|
|
@ -177,7 +177,7 @@ export const UpdateShortcutRequestSchema: GenMessage<UpdateShortcutRequest> = /*
|
|||
export type DeleteShortcutRequest = Message<"memos.api.v1.DeleteShortcutRequest"> & {
|
||||
/**
|
||||
* Required. The resource name of the shortcut to delete.
|
||||
* Format: users/{user}/shortcuts/{shortcut}
|
||||
* Format: users/{username}/shortcuts/{shortcut}
|
||||
*
|
||||
* @generated from field: string name = 1;
|
||||
*/
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue