diff --git a/server/router/api/v1/test/memo_share_service_test.go b/server/router/api/v1/test/memo_share_service_test.go index 110b83a35..946597d92 100644 --- a/server/router/api/v1/test/memo_share_service_test.go +++ b/server/router/api/v1/test/memo_share_service_test.go @@ -4,12 +4,14 @@ import ( "context" "strings" "testing" + "time" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" apiv1 "github.com/usememos/memos/proto/gen/api/v1" + "github.com/usememos/memos/store" ) func TestDeleteMemoShare_VerifiesShareBelongsToMemo(t *testing.T) { @@ -107,3 +109,107 @@ func TestGetMemoByShare_IncludesReactions(t *testing.T) { require.Equal(t, "👍", sharedMemo.Reactions[0].ReactionType) require.Equal(t, memo.Name, sharedMemo.Reactions[0].ContentId) } + +func TestGetMemoByShare_ReturnsNotFoundForUnknownShare(t *testing.T) { + ctx := context.Background() + + ts := NewTestService(t) + defer ts.Cleanup() + + _, err := ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{ + ShareId: "missing-share-token", + }) + require.Error(t, err) + require.Equal(t, codes.NotFound, status.Code(err)) +} + +func TestGetMemoByShare_ReturnsNotFoundForExpiredShare(t *testing.T) { + ctx := context.Background() + + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "share-expired") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "memo with expired share", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + + expiredTs := time.Now().Add(-time.Hour).Unix() + expiredShare, err := ts.Store.CreateMemoShare(ctx, &store.MemoShare{ + UID: "expired-share-token", + MemoID: parseMemoIDFromNameForTest(t, ts, memo.Name), + CreatorID: user.ID, + ExpiresTs: &expiredTs, + }) + require.NoError(t, err) + + _, err = ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{ + ShareId: expiredShare.UID, + }) + require.Error(t, err) + require.Equal(t, codes.NotFound, status.Code(err)) +} + +func TestGetMemoByShare_ReturnsNotFoundForArchivedMemo(t *testing.T) { + ctx := context.Background() + + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "share-archived") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + memoResp, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "memo that will be archived", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + + share, err := ts.Service.CreateMemoShare(userCtx, &apiv1.CreateMemoShareRequest{ + Parent: memoResp.Name, + MemoShare: &apiv1.MemoShare{}, + }) + require.NoError(t, err) + + memoID := parseMemoIDFromNameForTest(t, ts, memoResp.Name) + memo, err := ts.Store.GetMemo(ctx, &store.FindMemo{ID: &memoID}) + require.NoError(t, err) + require.NotNil(t, memo) + + archived := store.Archived + err = ts.Store.UpdateMemo(ctx, &store.UpdateMemo{ + ID: memo.ID, + RowStatus: &archived, + }) + require.NoError(t, err) + + shareToken := share.Name[strings.LastIndex(share.Name, "/")+1:] + _, err = ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{ + ShareId: shareToken, + }) + require.Error(t, err) + require.Equal(t, codes.NotFound, status.Code(err)) +} + +func parseMemoIDFromNameForTest(t *testing.T, ts *TestService, memoName string) int32 { + t.Helper() + + memoUID, ok := strings.CutPrefix(memoName, "memos/") + require.True(t, ok, "memo name must start with memos/: %s", memoName) + + memo, err := ts.Store.GetMemo(context.Background(), &store.FindMemo{UID: &memoUID}) + require.NoError(t, err) + require.NotNil(t, memo) + + return memo.ID +} diff --git a/store/test/containers.go b/store/test/containers.go index 0b98c5c2b..e9760b6a1 100644 --- a/store/test/containers.go +++ b/store/test/containers.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "fmt" + "net" + "net/url" "os" "strings" "sync" @@ -12,6 +14,7 @@ import ( "time" "github.com/docker/docker/api/types/container" + mysqldriver "github.com/go-sql-driver/mysql" "github.com/pkg/errors" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/mysql" @@ -20,7 +23,6 @@ import ( "github.com/testcontainers/testcontainers-go/wait" // Database drivers for connection verification. - _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" ) @@ -31,6 +33,9 @@ const ( // Memos container settings for migration testing. MemosDockerImage = "neosmemo/memos" StableMemosVersion = "stable" // Always points to the latest stable release + + mysqlNetworkAlias = "memos-mysql" + postgresNetworkAlias = "memos-postgres" ) var ( @@ -62,12 +67,23 @@ func getTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error) return testDockerNetwork.Load(), networkErr } +func requireTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error) { + nw, err := getTestNetwork(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to create test network") + } + if nw == nil { + return nil, errors.New("test network is unavailable") + } + return nw, nil +} + // GetMySQLDSN starts a MySQL container (if not already running) and creates a fresh database for this test. func GetMySQLDSN(t *testing.T) string { ctx := context.Background() mysqlOnce.Do(func() { - nw, err := getTestNetwork(ctx) + nw, err := requireTestNetwork(ctx) if err != nil { t.Fatalf("failed to create test network: %v", err) } @@ -86,7 +102,7 @@ func GetMySQLDSN(t *testing.T) string { wait.ForListeningPort("3306/tcp"), ).WithDeadline(120*time.Second), ), - network.WithNetwork(nil, nw), + network.WithNetwork([]string{mysqlNetworkAlias}, nw), ) if err != nil { t.Fatalf("failed to start MySQL container: %v", err) @@ -167,7 +183,7 @@ func GetPostgresDSN(t *testing.T) string { ctx := context.Background() postgresOnce.Do(func() { - nw, err := getTestNetwork(ctx) + nw, err := requireTestNetwork(ctx) if err != nil { t.Fatalf("failed to create test network: %v", err) } @@ -183,7 +199,7 @@ func GetPostgresDSN(t *testing.T) string { wait.ForListeningPort("5432/tcp"), ).WithDeadline(120*time.Second), ), - network.WithNetwork(nil, nw), + network.WithNetwork([]string{postgresNetworkAlias}, nw), ) if err != nil { t.Fatalf("failed to start PostgreSQL container: %v", err) @@ -264,6 +280,11 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon "MEMOS_MODE": "prod", } + nw, err := requireTestNetwork(ctx) + if err != nil { + return nil, err + } + var opts []testcontainers.ContainerCustomizer switch cfg.Driver { @@ -272,6 +293,12 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon opts = append(opts, testcontainers.WithHostConfigModifier(func(hc *container.HostConfig) { hc.Binds = append(hc.Binds, fmt.Sprintf("%s:%s", cfg.DataDir, "/var/opt/memos")) })) + case "mysql", "postgres": + if cfg.DSN == "" { + return nil, errors.Errorf("dsn is required for %s migration testing", cfg.Driver) + } + env["MEMOS_DRIVER"] = cfg.Driver + env["MEMOS_DSN"] = cfg.DSN default: return nil, errors.Errorf("unsupported driver for migration testing: %s", cfg.Driver) } @@ -303,6 +330,7 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon } // Apply options + opts = append(opts, network.WithNetwork(nil, nw)) for _, opt := range opts { if err := opt.Customize(&genericReq); err != nil { return nil, errors.Wrap(err, "failed to apply container option") @@ -316,3 +344,27 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon return ctr, nil } + +func getContainerDSN(driver, hostDSN string) (string, error) { + switch driver { + case "mysql": + cfg, err := mysqldriver.ParseDSN(hostDSN) + if err != nil { + return "", errors.Wrap(err, "failed to parse mysql dsn") + } + cfg.Net = "tcp" + cfg.Addr = net.JoinHostPort(mysqlNetworkAlias, "3306") + return cfg.FormatDSN(), nil + case "postgres": + u, err := url.Parse(hostDSN) + if err != nil { + return "", errors.Wrap(err, "failed to parse postgres dsn") + } + u.Host = net.JoinHostPort(postgresNetworkAlias, "5432") + return u.String(), nil + case "sqlite": + return hostDSN, nil + default: + return "", errors.Errorf("unsupported driver for container dsn: %s", driver) + } +} diff --git a/store/test/migrator_upgrade_test.go b/store/test/migrator_upgrade_test.go new file mode 100644 index 000000000..4b8254ed1 --- /dev/null +++ b/store/test/migrator_upgrade_test.go @@ -0,0 +1,274 @@ +package test + +import ( + "context" + "database/sql" + "fmt" + "os" + "strings" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + + storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/store" +) + +func TestMigrationFromV0262PreservesLegacyData(t *testing.T) { + if testing.Short() { + t.Skip("skipping container-based upgrade test in short mode") + } + if os.Getenv("SKIP_CONTAINER_TESTS") == "1" { + t.Skip("skipping container-based test (SKIP_CONTAINER_TESTS=1)") + } + + ctx := context.Background() + driver := getDriverFromEnv() + + cfg, hostDSN := prepareV0262MigrationTest(t, driver) + t.Logf("Starting Memos %s container for %s schema bootstrap...", cfg.Version, driver) + container, err := StartMemosContainer(ctx, cfg) + require.NoError(t, err, "failed to start v0.26.2 memos container") + t.Cleanup(func() { + if container != nil { + _ = container.Terminate(ctx) + } + }) + + legacyStore := NewTestingStoreWithDSN(ctx, t, driver, hostDSN) + require.Eventually(t, func() bool { + setting, err := legacyStore.GetInstanceBasicSetting(ctx) + return err == nil && setting != nil && setting.SchemaVersion != "" + }, 45*time.Second, 500*time.Millisecond, "legacy schema should be initialized by old container") + + settingBeforeSeed, err := legacyStore.GetInstanceBasicSetting(ctx) + require.NoError(t, err) + t.Logf("Legacy schema version before migration: %s", settingBeforeSeed.SchemaVersion) + + err = container.Terminate(ctx) + require.NoError(t, err, "failed to stop v0.26.2 memos container") + container = nil + + db := openMigrationSQLDB(t, driver, hostDSN) + defer db.Close() + + seedLegacyMigrationData(ctx, t, driver, db) + + count, err := countSystemSetting(ctx, db, "STORAGE") + require.NoError(t, err) + require.Zero(t, count, "v0.26.2 database should not have a STORAGE setting before migration") + + ts := NewTestingStoreWithDSN(ctx, t, driver, hostDSN) + err = ts.Migrate(ctx) + require.NoError(t, err, "migration from v0.26.2 should succeed for %s", driver) + + currentVersion, err := ts.GetCurrentSchemaVersion() + require.NoError(t, err) + currentSetting, err := ts.GetInstanceBasicSetting(ctx) + require.NoError(t, err) + require.Equal(t, currentVersion, currentSetting.SchemaVersion, "schema version should be updated") + + storageSetting, err := ts.GetInstanceStorageSetting(ctx) + require.NoError(t, err) + require.Equal(t, storepb.InstanceStorageSetting_DATABASE, storageSetting.StorageType, "existing installs should stay on DATABASE storage") + + idps, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{}) + require.NoError(t, err) + require.Len(t, idps, 2) + idpUIDsByName := map[string]string{} + for _, idp := range idps { + idpUIDsByName[idp.Name] = idp.Uid + } + require.Equal(t, "00000191", idpUIDsByName["Legacy Google"]) + require.Equal(t, "00000192", idpUIDsByName["Legacy GitHub"]) + + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{}) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.NotNil(t, inboxes[0].Message) + require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inboxes[0].Message.Type) + require.Equal(t, int32(102), inboxes[0].Message.GetMemoComment().MemoId) + require.Equal(t, int32(101), inboxes[0].Message.GetMemoComment().RelatedMemoId) + + activityExists, err := tableExists(ctx, db, driver, "activity") + require.NoError(t, err) + require.False(t, activityExists, "activity table should be removed after migration") + + memoShareExists, err := tableExists(ctx, db, driver, "memo_share") + require.NoError(t, err) + require.True(t, memoShareExists, "memo_share table should be created") + + share, err := ts.CreateMemoShare(ctx, &store.MemoShare{ + UID: "post-upgrade-share", + MemoID: 101, + CreatorID: 11, + }) + require.NoError(t, err) + require.Equal(t, "post-upgrade-share", share.UID) + + postUpgradeUser, err := createTestingUserWithRole(ctx, ts, "postupgrade", store.RoleUser) + require.NoError(t, err) + postUpgradeMemo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "post-upgrade-memo-v0262", + CreatorID: postUpgradeUser.ID, + Content: "created after v0.26.2 migration", + Visibility: store.Public, + }) + require.NoError(t, err) + require.Equal(t, "created after v0.26.2 migration", postUpgradeMemo.Content) +} + +func prepareV0262MigrationTest(t *testing.T, driver string) (MemosContainerConfig, string) { + t.Helper() + + const version = "0.26.2" + + switch driver { + case "sqlite": + dataDir := t.TempDir() + return MemosContainerConfig{ + Version: version, + Driver: driver, + DataDir: dataDir, + }, fmt.Sprintf("%s/memos_prod.db", dataDir) + case "mysql": + hostDSN := GetMySQLDSN(t) + containerDSN, err := getContainerDSN(driver, hostDSN) + require.NoError(t, err) + return MemosContainerConfig{ + Version: version, + Driver: driver, + DSN: containerDSN, + }, hostDSN + case "postgres": + hostDSN := GetPostgresDSN(t) + containerDSN, err := getContainerDSN(driver, hostDSN) + require.NoError(t, err) + return MemosContainerConfig{ + Version: version, + Driver: driver, + DSN: containerDSN, + }, hostDSN + default: + t.Fatalf("unsupported driver: %s", driver) + return MemosContainerConfig{}, "" + } +} + +func openMigrationSQLDB(t *testing.T, driver, dsn string) *sql.DB { + t.Helper() + + db, err := sql.Open(driver, dsn) + require.NoError(t, err) + require.NoError(t, db.Ping()) + return db +} + +func seedLegacyMigrationData(ctx context.Context, t *testing.T, driver string, db *sql.DB) { + t.Helper() + + execMigrationSQL(t, db, legacyInsertUserSQL(driver, 11, "owner")) + execMigrationSQL(t, db, legacyInsertUserSQL(driver, 12, "commenter")) + execMigrationSQL(t, db, legacyInsertMemoSQL(101, 11, "legacy-parent", "parent memo")) + execMigrationSQL(t, db, legacyInsertMemoSQL(102, 12, "legacy-comment", "comment memo")) + execMigrationSQL(t, db, legacyInsertActivitySQL(201, 12)) + execMigrationSQL(t, db, legacyInsertInboxSQL(301, 12, 11, 201)) + execMigrationSQL(t, db, legacyInsertIDPSQL(401, "Legacy Google")) + execMigrationSQL(t, db, legacyInsertIDPSQL(402, "Legacy GitHub")) + + var message string + err := db.QueryRowContext(ctx, "SELECT message FROM inbox WHERE id = 301").Scan(&message) + require.NoError(t, err) + require.Contains(t, message, "\"activityId\":201") + require.NotContains(t, message, "\"memoComment\"") +} + +func execMigrationSQL(t *testing.T, db *sql.DB, query string) { + t.Helper() + _, err := db.Exec(query) + require.NoError(t, err, "failed to execute SQL: %s", query) +} + +func legacyInsertUserSQL(driver string, id int, username string) string { + table := "user" + switch driver { + case "mysql": + table = "`user`" + case "postgres", "sqlite": + table = `"user"` + default: + // Keep the unquoted fallback for unknown test drivers. + } + + return fmt.Sprintf( + "INSERT INTO %s (id, username, role, email, nickname, password_hash, avatar_url, description) VALUES (%d, '%s', 'USER', '%s@example.com', '%s', 'legacy-hash', '', 'legacy user')", + table, id, username, username, username, + ) +} + +func legacyInsertMemoSQL(id, creatorID int, uid, content string) string { + payload := "{}" + return fmt.Sprintf( + "INSERT INTO memo (id, uid, creator_id, content, visibility, payload) VALUES (%d, '%s', %d, '%s', 'PRIVATE', '%s')", + id, uid, creatorID, content, payload, + ) +} + +func legacyInsertActivitySQL(id, creatorID int) string { + payload := `{"memoComment":{"memoId":102,"relatedMemoId":101}}` + return fmt.Sprintf( + "INSERT INTO activity (id, creator_id, type, level, payload) VALUES (%d, %d, 'MEMO_COMMENT', 'INFO', '%s')", + id, creatorID, payload, + ) +} + +func legacyInsertInboxSQL(id, senderID, receiverID, activityID int) string { + message := fmt.Sprintf(`{"type":"MEMO_COMMENT","activityId":%d}`, activityID) + return fmt.Sprintf( + "INSERT INTO inbox (id, sender_id, receiver_id, status, message) VALUES (%d, %d, %d, 'UNREAD', '%s')", + id, senderID, receiverID, message, + ) +} + +func legacyInsertIDPSQL(id int, name string) string { + config := `{"clientId":"legacy-client","clientSecret":"legacy-secret","authUrl":"https://example.com/auth","tokenUrl":"https://example.com/token","userInfoUrl":"https://example.com/userinfo"}` + return fmt.Sprintf( + "INSERT INTO idp (id, name, type, identifier_filter, config) VALUES (%d, '%s', 'OAUTH2', '', '%s')", + id, name, config, + ) +} + +func countSystemSetting(ctx context.Context, db *sql.DB, name string) (int, error) { + var count int + err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_setting WHERE name = ?", name).Scan(&count) + if err == nil { + return count, nil + } + + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_setting WHERE name = $1", name).Scan(&count) + return count, err +} + +func tableExists(ctx context.Context, db *sql.DB, driver, table string) (bool, error) { + switch driver { + case "sqlite": + var name string + err := db.QueryRowContext(ctx, "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", table).Scan(&name) + if err == sql.ErrNoRows { + return false, nil + } + return err == nil, err + case "mysql": + var count int + err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = ?", table).Scan(&count) + return count > 0, err + case "postgres": + var regclass sql.NullString + err := db.QueryRowContext(ctx, "SELECT to_regclass($1)", "public."+table).Scan(®class) + return regclass.Valid && strings.EqualFold(regclass.String, table), err + default: + return false, errors.Errorf("unsupported driver: %s", driver) + } +}