mirror of https://github.com/usememos/memos.git
chore: add migration upgrade coverage (#5796)
This commit is contained in:
parent
e520b637fd
commit
7c708ee27e
|
|
@ -4,12 +4,14 @@ import (
|
|||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestDeleteMemoShare_VerifiesShareBelongsToMemo(t *testing.T) {
|
||||
|
|
@ -107,3 +109,107 @@ func TestGetMemoByShare_IncludesReactions(t *testing.T) {
|
|||
require.Equal(t, "👍", sharedMemo.Reactions[0].ReactionType)
|
||||
require.Equal(t, memo.Name, sharedMemo.Reactions[0].ContentId)
|
||||
}
|
||||
|
||||
func TestGetMemoByShare_ReturnsNotFoundForUnknownShare(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
_, err := ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{
|
||||
ShareId: "missing-share-token",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.NotFound, status.Code(err))
|
||||
}
|
||||
|
||||
func TestGetMemoByShare_ReturnsNotFoundForExpiredShare(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "share-expired")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "memo with expired share",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
expiredTs := time.Now().Add(-time.Hour).Unix()
|
||||
expiredShare, err := ts.Store.CreateMemoShare(ctx, &store.MemoShare{
|
||||
UID: "expired-share-token",
|
||||
MemoID: parseMemoIDFromNameForTest(t, ts, memo.Name),
|
||||
CreatorID: user.ID,
|
||||
ExpiresTs: &expiredTs,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{
|
||||
ShareId: expiredShare.UID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.NotFound, status.Code(err))
|
||||
}
|
||||
|
||||
func TestGetMemoByShare_ReturnsNotFoundForArchivedMemo(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "share-archived")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
memoResp, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
|
||||
Memo: &apiv1.Memo{
|
||||
Content: "memo that will be archived",
|
||||
Visibility: apiv1.Visibility_PRIVATE,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
share, err := ts.Service.CreateMemoShare(userCtx, &apiv1.CreateMemoShareRequest{
|
||||
Parent: memoResp.Name,
|
||||
MemoShare: &apiv1.MemoShare{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoID := parseMemoIDFromNameForTest(t, ts, memoResp.Name)
|
||||
memo, err := ts.Store.GetMemo(ctx, &store.FindMemo{ID: &memoID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
archived := store.Archived
|
||||
err = ts.Store.UpdateMemo(ctx, &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
RowStatus: &archived,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
shareToken := share.Name[strings.LastIndex(share.Name, "/")+1:]
|
||||
_, err = ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{
|
||||
ShareId: shareToken,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.NotFound, status.Code(err))
|
||||
}
|
||||
|
||||
func parseMemoIDFromNameForTest(t *testing.T, ts *TestService, memoName string) int32 {
|
||||
t.Helper()
|
||||
|
||||
memoUID, ok := strings.CutPrefix(memoName, "memos/")
|
||||
require.True(t, ok, "memo name must start with memos/: %s", memoName)
|
||||
|
||||
memo, err := ts.Store.GetMemo(context.Background(), &store.FindMemo{UID: &memoUID})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memo)
|
||||
|
||||
return memo.ID
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
|
@ -12,6 +14,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
mysqldriver "github.com/go-sql-driver/mysql"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mysql"
|
||||
|
|
@ -20,7 +23,6 @@ import (
|
|||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
|
||||
// Database drivers for connection verification.
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
|
|
@ -31,6 +33,9 @@ const (
|
|||
// Memos container settings for migration testing.
|
||||
MemosDockerImage = "neosmemo/memos"
|
||||
StableMemosVersion = "stable" // Always points to the latest stable release
|
||||
|
||||
mysqlNetworkAlias = "memos-mysql"
|
||||
postgresNetworkAlias = "memos-postgres"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -62,12 +67,23 @@ func getTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error)
|
|||
return testDockerNetwork.Load(), networkErr
|
||||
}
|
||||
|
||||
func requireTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error) {
|
||||
nw, err := getTestNetwork(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create test network")
|
||||
}
|
||||
if nw == nil {
|
||||
return nil, errors.New("test network is unavailable")
|
||||
}
|
||||
return nw, nil
|
||||
}
|
||||
|
||||
// GetMySQLDSN starts a MySQL container (if not already running) and creates a fresh database for this test.
|
||||
func GetMySQLDSN(t *testing.T) string {
|
||||
ctx := context.Background()
|
||||
|
||||
mysqlOnce.Do(func() {
|
||||
nw, err := getTestNetwork(ctx)
|
||||
nw, err := requireTestNetwork(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test network: %v", err)
|
||||
}
|
||||
|
|
@ -86,7 +102,7 @@ func GetMySQLDSN(t *testing.T) string {
|
|||
wait.ForListeningPort("3306/tcp"),
|
||||
).WithDeadline(120*time.Second),
|
||||
),
|
||||
network.WithNetwork(nil, nw),
|
||||
network.WithNetwork([]string{mysqlNetworkAlias}, nw),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start MySQL container: %v", err)
|
||||
|
|
@ -167,7 +183,7 @@ func GetPostgresDSN(t *testing.T) string {
|
|||
ctx := context.Background()
|
||||
|
||||
postgresOnce.Do(func() {
|
||||
nw, err := getTestNetwork(ctx)
|
||||
nw, err := requireTestNetwork(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test network: %v", err)
|
||||
}
|
||||
|
|
@ -183,7 +199,7 @@ func GetPostgresDSN(t *testing.T) string {
|
|||
wait.ForListeningPort("5432/tcp"),
|
||||
).WithDeadline(120*time.Second),
|
||||
),
|
||||
network.WithNetwork(nil, nw),
|
||||
network.WithNetwork([]string{postgresNetworkAlias}, nw),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start PostgreSQL container: %v", err)
|
||||
|
|
@ -264,6 +280,11 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon
|
|||
"MEMOS_MODE": "prod",
|
||||
}
|
||||
|
||||
nw, err := requireTestNetwork(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opts []testcontainers.ContainerCustomizer
|
||||
|
||||
switch cfg.Driver {
|
||||
|
|
@ -272,6 +293,12 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon
|
|||
opts = append(opts, testcontainers.WithHostConfigModifier(func(hc *container.HostConfig) {
|
||||
hc.Binds = append(hc.Binds, fmt.Sprintf("%s:%s", cfg.DataDir, "/var/opt/memos"))
|
||||
}))
|
||||
case "mysql", "postgres":
|
||||
if cfg.DSN == "" {
|
||||
return nil, errors.Errorf("dsn is required for %s migration testing", cfg.Driver)
|
||||
}
|
||||
env["MEMOS_DRIVER"] = cfg.Driver
|
||||
env["MEMOS_DSN"] = cfg.DSN
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported driver for migration testing: %s", cfg.Driver)
|
||||
}
|
||||
|
|
@ -303,6 +330,7 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon
|
|||
}
|
||||
|
||||
// Apply options
|
||||
opts = append(opts, network.WithNetwork(nil, nw))
|
||||
for _, opt := range opts {
|
||||
if err := opt.Customize(&genericReq); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to apply container option")
|
||||
|
|
@ -316,3 +344,27 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon
|
|||
|
||||
return ctr, nil
|
||||
}
|
||||
|
||||
func getContainerDSN(driver, hostDSN string) (string, error) {
|
||||
switch driver {
|
||||
case "mysql":
|
||||
cfg, err := mysqldriver.ParseDSN(hostDSN)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to parse mysql dsn")
|
||||
}
|
||||
cfg.Net = "tcp"
|
||||
cfg.Addr = net.JoinHostPort(mysqlNetworkAlias, "3306")
|
||||
return cfg.FormatDSN(), nil
|
||||
case "postgres":
|
||||
u, err := url.Parse(hostDSN)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to parse postgres dsn")
|
||||
}
|
||||
u.Host = net.JoinHostPort(postgresNetworkAlias, "5432")
|
||||
return u.String(), nil
|
||||
case "sqlite":
|
||||
return hostDSN, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported driver for container dsn: %s", driver)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,274 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestMigrationFromV0262PreservesLegacyData(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping container-based upgrade test in short mode")
|
||||
}
|
||||
if os.Getenv("SKIP_CONTAINER_TESTS") == "1" {
|
||||
t.Skip("skipping container-based test (SKIP_CONTAINER_TESTS=1)")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
driver := getDriverFromEnv()
|
||||
|
||||
cfg, hostDSN := prepareV0262MigrationTest(t, driver)
|
||||
t.Logf("Starting Memos %s container for %s schema bootstrap...", cfg.Version, driver)
|
||||
container, err := StartMemosContainer(ctx, cfg)
|
||||
require.NoError(t, err, "failed to start v0.26.2 memos container")
|
||||
t.Cleanup(func() {
|
||||
if container != nil {
|
||||
_ = container.Terminate(ctx)
|
||||
}
|
||||
})
|
||||
|
||||
legacyStore := NewTestingStoreWithDSN(ctx, t, driver, hostDSN)
|
||||
require.Eventually(t, func() bool {
|
||||
setting, err := legacyStore.GetInstanceBasicSetting(ctx)
|
||||
return err == nil && setting != nil && setting.SchemaVersion != ""
|
||||
}, 45*time.Second, 500*time.Millisecond, "legacy schema should be initialized by old container")
|
||||
|
||||
settingBeforeSeed, err := legacyStore.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Legacy schema version before migration: %s", settingBeforeSeed.SchemaVersion)
|
||||
|
||||
err = container.Terminate(ctx)
|
||||
require.NoError(t, err, "failed to stop v0.26.2 memos container")
|
||||
container = nil
|
||||
|
||||
db := openMigrationSQLDB(t, driver, hostDSN)
|
||||
defer db.Close()
|
||||
|
||||
seedLegacyMigrationData(ctx, t, driver, db)
|
||||
|
||||
count, err := countSystemSetting(ctx, db, "STORAGE")
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, count, "v0.26.2 database should not have a STORAGE setting before migration")
|
||||
|
||||
ts := NewTestingStoreWithDSN(ctx, t, driver, hostDSN)
|
||||
err = ts.Migrate(ctx)
|
||||
require.NoError(t, err, "migration from v0.26.2 should succeed for %s", driver)
|
||||
|
||||
currentVersion, err := ts.GetCurrentSchemaVersion()
|
||||
require.NoError(t, err)
|
||||
currentSetting, err := ts.GetInstanceBasicSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, currentVersion, currentSetting.SchemaVersion, "schema version should be updated")
|
||||
|
||||
storageSetting, err := ts.GetInstanceStorageSetting(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, storepb.InstanceStorageSetting_DATABASE, storageSetting.StorageType, "existing installs should stay on DATABASE storage")
|
||||
|
||||
idps, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, idps, 2)
|
||||
idpUIDsByName := map[string]string{}
|
||||
for _, idp := range idps {
|
||||
idpUIDsByName[idp.Name] = idp.Uid
|
||||
}
|
||||
require.Equal(t, "00000191", idpUIDsByName["Legacy Google"])
|
||||
require.Equal(t, "00000192", idpUIDsByName["Legacy GitHub"])
|
||||
|
||||
inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, inboxes, 1)
|
||||
require.NotNil(t, inboxes[0].Message)
|
||||
require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inboxes[0].Message.Type)
|
||||
require.Equal(t, int32(102), inboxes[0].Message.GetMemoComment().MemoId)
|
||||
require.Equal(t, int32(101), inboxes[0].Message.GetMemoComment().RelatedMemoId)
|
||||
|
||||
activityExists, err := tableExists(ctx, db, driver, "activity")
|
||||
require.NoError(t, err)
|
||||
require.False(t, activityExists, "activity table should be removed after migration")
|
||||
|
||||
memoShareExists, err := tableExists(ctx, db, driver, "memo_share")
|
||||
require.NoError(t, err)
|
||||
require.True(t, memoShareExists, "memo_share table should be created")
|
||||
|
||||
share, err := ts.CreateMemoShare(ctx, &store.MemoShare{
|
||||
UID: "post-upgrade-share",
|
||||
MemoID: 101,
|
||||
CreatorID: 11,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "post-upgrade-share", share.UID)
|
||||
|
||||
postUpgradeUser, err := createTestingUserWithRole(ctx, ts, "postupgrade", store.RoleUser)
|
||||
require.NoError(t, err)
|
||||
postUpgradeMemo, err := ts.CreateMemo(ctx, &store.Memo{
|
||||
UID: "post-upgrade-memo-v0262",
|
||||
CreatorID: postUpgradeUser.ID,
|
||||
Content: "created after v0.26.2 migration",
|
||||
Visibility: store.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "created after v0.26.2 migration", postUpgradeMemo.Content)
|
||||
}
|
||||
|
||||
func prepareV0262MigrationTest(t *testing.T, driver string) (MemosContainerConfig, string) {
|
||||
t.Helper()
|
||||
|
||||
const version = "0.26.2"
|
||||
|
||||
switch driver {
|
||||
case "sqlite":
|
||||
dataDir := t.TempDir()
|
||||
return MemosContainerConfig{
|
||||
Version: version,
|
||||
Driver: driver,
|
||||
DataDir: dataDir,
|
||||
}, fmt.Sprintf("%s/memos_prod.db", dataDir)
|
||||
case "mysql":
|
||||
hostDSN := GetMySQLDSN(t)
|
||||
containerDSN, err := getContainerDSN(driver, hostDSN)
|
||||
require.NoError(t, err)
|
||||
return MemosContainerConfig{
|
||||
Version: version,
|
||||
Driver: driver,
|
||||
DSN: containerDSN,
|
||||
}, hostDSN
|
||||
case "postgres":
|
||||
hostDSN := GetPostgresDSN(t)
|
||||
containerDSN, err := getContainerDSN(driver, hostDSN)
|
||||
require.NoError(t, err)
|
||||
return MemosContainerConfig{
|
||||
Version: version,
|
||||
Driver: driver,
|
||||
DSN: containerDSN,
|
||||
}, hostDSN
|
||||
default:
|
||||
t.Fatalf("unsupported driver: %s", driver)
|
||||
return MemosContainerConfig{}, ""
|
||||
}
|
||||
}
|
||||
|
||||
func openMigrationSQLDB(t *testing.T, driver, dsn string) *sql.DB {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open(driver, dsn)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Ping())
|
||||
return db
|
||||
}
|
||||
|
||||
func seedLegacyMigrationData(ctx context.Context, t *testing.T, driver string, db *sql.DB) {
|
||||
t.Helper()
|
||||
|
||||
execMigrationSQL(t, db, legacyInsertUserSQL(driver, 11, "owner"))
|
||||
execMigrationSQL(t, db, legacyInsertUserSQL(driver, 12, "commenter"))
|
||||
execMigrationSQL(t, db, legacyInsertMemoSQL(101, 11, "legacy-parent", "parent memo"))
|
||||
execMigrationSQL(t, db, legacyInsertMemoSQL(102, 12, "legacy-comment", "comment memo"))
|
||||
execMigrationSQL(t, db, legacyInsertActivitySQL(201, 12))
|
||||
execMigrationSQL(t, db, legacyInsertInboxSQL(301, 12, 11, 201))
|
||||
execMigrationSQL(t, db, legacyInsertIDPSQL(401, "Legacy Google"))
|
||||
execMigrationSQL(t, db, legacyInsertIDPSQL(402, "Legacy GitHub"))
|
||||
|
||||
var message string
|
||||
err := db.QueryRowContext(ctx, "SELECT message FROM inbox WHERE id = 301").Scan(&message)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, message, "\"activityId\":201")
|
||||
require.NotContains(t, message, "\"memoComment\"")
|
||||
}
|
||||
|
||||
func execMigrationSQL(t *testing.T, db *sql.DB, query string) {
|
||||
t.Helper()
|
||||
_, err := db.Exec(query)
|
||||
require.NoError(t, err, "failed to execute SQL: %s", query)
|
||||
}
|
||||
|
||||
func legacyInsertUserSQL(driver string, id int, username string) string {
|
||||
table := "user"
|
||||
switch driver {
|
||||
case "mysql":
|
||||
table = "`user`"
|
||||
case "postgres", "sqlite":
|
||||
table = `"user"`
|
||||
default:
|
||||
// Keep the unquoted fallback for unknown test drivers.
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO %s (id, username, role, email, nickname, password_hash, avatar_url, description) VALUES (%d, '%s', 'USER', '%s@example.com', '%s', 'legacy-hash', '', 'legacy user')",
|
||||
table, id, username, username, username,
|
||||
)
|
||||
}
|
||||
|
||||
func legacyInsertMemoSQL(id, creatorID int, uid, content string) string {
|
||||
payload := "{}"
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO memo (id, uid, creator_id, content, visibility, payload) VALUES (%d, '%s', %d, '%s', 'PRIVATE', '%s')",
|
||||
id, uid, creatorID, content, payload,
|
||||
)
|
||||
}
|
||||
|
||||
func legacyInsertActivitySQL(id, creatorID int) string {
|
||||
payload := `{"memoComment":{"memoId":102,"relatedMemoId":101}}`
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO activity (id, creator_id, type, level, payload) VALUES (%d, %d, 'MEMO_COMMENT', 'INFO', '%s')",
|
||||
id, creatorID, payload,
|
||||
)
|
||||
}
|
||||
|
||||
func legacyInsertInboxSQL(id, senderID, receiverID, activityID int) string {
|
||||
message := fmt.Sprintf(`{"type":"MEMO_COMMENT","activityId":%d}`, activityID)
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO inbox (id, sender_id, receiver_id, status, message) VALUES (%d, %d, %d, 'UNREAD', '%s')",
|
||||
id, senderID, receiverID, message,
|
||||
)
|
||||
}
|
||||
|
||||
func legacyInsertIDPSQL(id int, name string) string {
|
||||
config := `{"clientId":"legacy-client","clientSecret":"legacy-secret","authUrl":"https://example.com/auth","tokenUrl":"https://example.com/token","userInfoUrl":"https://example.com/userinfo"}`
|
||||
return fmt.Sprintf(
|
||||
"INSERT INTO idp (id, name, type, identifier_filter, config) VALUES (%d, '%s', 'OAUTH2', '', '%s')",
|
||||
id, name, config,
|
||||
)
|
||||
}
|
||||
|
||||
func countSystemSetting(ctx context.Context, db *sql.DB, name string) (int, error) {
|
||||
var count int
|
||||
err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_setting WHERE name = ?", name).Scan(&count)
|
||||
if err == nil {
|
||||
return count, nil
|
||||
}
|
||||
|
||||
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_setting WHERE name = $1", name).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func tableExists(ctx context.Context, db *sql.DB, driver, table string) (bool, error) {
|
||||
switch driver {
|
||||
case "sqlite":
|
||||
var name string
|
||||
err := db.QueryRowContext(ctx, "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", table).Scan(&name)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return err == nil, err
|
||||
case "mysql":
|
||||
var count int
|
||||
err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = ?", table).Scan(&count)
|
||||
return count > 0, err
|
||||
case "postgres":
|
||||
var regclass sql.NullString
|
||||
err := db.QueryRowContext(ctx, "SELECT to_regclass($1)", "public."+table).Scan(®class)
|
||||
return regclass.Valid && strings.EqualFold(regclass.String, table), err
|
||||
default:
|
||||
return false, errors.Errorf("unsupported driver: %s", driver)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue