mirror of https://github.com/usememos/memos.git
refactor: schema migrator
This commit is contained in:
parent
d386b83b7b
commit
3fd29f6493
|
|
@ -1,6 +1,7 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
|
|
@ -47,6 +48,15 @@ func (d *DB) Close() error {
|
|||
return d.db.Close()
|
||||
}
|
||||
|
||||
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
|
||||
var exists bool
|
||||
err := d.db.QueryRow("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE TABLE_NAME = 'memo' AND TABLE_TYPE = 'BASE TABLE')").Scan(&exists)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "failed to check if database is initialized")
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func mergeDSN(baseDSN string) (string, error) {
|
||||
config, err := mysql.ParseDSN(baseDSN)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log"
|
||||
|
||||
|
|
@ -15,7 +16,6 @@ import (
|
|||
type DB struct {
|
||||
db *sql.DB
|
||||
profile *profile.Profile
|
||||
// Add any other fields as needed
|
||||
}
|
||||
|
||||
func NewDB(profile *profile.Profile) (store.Driver, error) {
|
||||
|
|
@ -46,3 +46,12 @@ func (d *DB) GetDB() *sql.DB {
|
|||
func (d *DB) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
||||
|
||||
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
|
||||
var exists bool
|
||||
err := d.db.QueryRow("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'memo' AND table_type = 'BASE TABLE')").Scan(&exists)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "failed to check if database is initialized")
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
|
@ -57,3 +58,13 @@ func (d *DB) GetDB() *sql.DB {
|
|||
func (d *DB) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
||||
|
||||
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
|
||||
// Check if the database is initialized by checking if the memo table exists.
|
||||
var exists bool
|
||||
err := d.db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name='memo')").Scan(&exists)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "failed to check if database is initialized")
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ type Driver interface {
|
|||
GetDB() *sql.DB
|
||||
Close() error
|
||||
|
||||
IsInitialized(ctx context.Context) (bool, error)
|
||||
|
||||
// MigrationHistory model related methods.
|
||||
FindMigrationHistoryList(ctx context.Context, find *FindMigrationHistory) ([]*MigrationHistory, error)
|
||||
UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error)
|
||||
|
|
|
|||
|
|
@ -40,26 +40,22 @@ func (s *Store) Migrate(ctx context.Context) error {
|
|||
}
|
||||
|
||||
if s.profile.Mode == "prod" {
|
||||
migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
|
||||
workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to find migration history")
|
||||
return errors.Wrap(err, "failed to get workspace basic setting")
|
||||
}
|
||||
if len(migrationHistoryList) == 0 {
|
||||
return errors.Errorf("no migration history found")
|
||||
}
|
||||
|
||||
migrationHistoryVersions := []string{}
|
||||
for _, migrationHistory := range migrationHistoryList {
|
||||
migrationHistoryVersions = append(migrationHistoryVersions, migrationHistory.Version)
|
||||
}
|
||||
sort.Sort(version.SortVersion(migrationHistoryVersions))
|
||||
latestMigrationHistoryVersion := migrationHistoryVersions[len(migrationHistoryVersions)-1]
|
||||
schemaVersion, err := s.GetCurrentSchemaVersion()
|
||||
currentSchemaVersion, err := s.GetCurrentSchemaVersion()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get current schema version")
|
||||
}
|
||||
|
||||
if version.IsVersionGreaterThan(schemaVersion, latestMigrationHistoryVersion) {
|
||||
if version.IsVersionGreaterThan(workspaceBasicSetting.SchemaVersion, currentSchemaVersion) {
|
||||
slog.Error("cannot downgrade schema version",
|
||||
slog.String("databaseVersion", workspaceBasicSetting.SchemaVersion),
|
||||
slog.String("currentVersion", currentSchemaVersion),
|
||||
)
|
||||
return errors.Errorf("cannot downgrade schema version from %s to %s", workspaceBasicSetting.SchemaVersion, currentSchemaVersion)
|
||||
}
|
||||
if version.IsVersionGreaterThan(currentSchemaVersion, workspaceBasicSetting.SchemaVersion) {
|
||||
filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s*/*.sql", s.getMigrationBasePath()))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read migration files")
|
||||
|
|
@ -73,13 +69,13 @@ func (s *Store) Migrate(ctx context.Context) error {
|
|||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
slog.Info("start migration", slog.String("currentSchemaVersion", latestMigrationHistoryVersion), slog.String("targetSchemaVersion", schemaVersion))
|
||||
slog.Info("start migration", slog.String("currentSchemaVersion", workspaceBasicSetting.SchemaVersion), slog.String("targetSchemaVersion", currentSchemaVersion))
|
||||
for _, filePath := range filePaths {
|
||||
fileSchemaVersion, err := s.getSchemaVersionOfMigrateScript(filePath)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get schema version of migrate script")
|
||||
}
|
||||
if version.IsVersionGreaterThan(fileSchemaVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(schemaVersion, fileSchemaVersion) {
|
||||
if version.IsVersionGreaterThan(fileSchemaVersion, workspaceBasicSetting.SchemaVersion) && version.IsVersionGreaterOrEqualThan(currentSchemaVersion, fileSchemaVersion) {
|
||||
bytes, err := migrationFS.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to read minor version migration file: %s", filePath)
|
||||
|
|
@ -90,20 +86,11 @@ func (s *Store) Migrate(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return errors.Wrap(err, "failed to commit transaction")
|
||||
}
|
||||
slog.Info("end migrate")
|
||||
|
||||
// Upsert the current schema version to migration_history.
|
||||
// TODO: retire using migration history later.
|
||||
if _, err = s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{
|
||||
Version: schemaVersion,
|
||||
}); err != nil {
|
||||
return errors.Wrapf(err, "failed to upsert migration history with version: %s", schemaVersion)
|
||||
}
|
||||
if err := s.updateCurrentSchemaVersion(ctx, schemaVersion); err != nil {
|
||||
if err := s.updateCurrentSchemaVersion(ctx, currentSchemaVersion); err != nil {
|
||||
return errors.Wrap(err, "failed to update current schema version")
|
||||
}
|
||||
}
|
||||
|
|
@ -117,23 +104,17 @@ func (s *Store) Migrate(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (s *Store) preMigrate(ctx context.Context) error {
|
||||
// TODO: using schema version in basic setting instead of migration history.
|
||||
migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
|
||||
// If any error occurs or no migration history found, apply the latest schema.
|
||||
if err != nil || len(migrationHistoryList) == 0 {
|
||||
if err != nil {
|
||||
slog.Warn("failed to find migration history in pre-migrate", slog.String("error", err.Error()))
|
||||
}
|
||||
initialized, err := s.driver.IsInitialized(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to check if database is initialized")
|
||||
}
|
||||
|
||||
if !initialized {
|
||||
filePath := s.getMigrationBasePath() + LatestSchemaFileName
|
||||
bytes, err := migrationFS.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return errors.Errorf("failed to read latest schema file: %s", err)
|
||||
}
|
||||
schemaVersion, err := s.GetCurrentSchemaVersion()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get current schema version")
|
||||
}
|
||||
|
||||
// Start a transaction to apply the latest schema.
|
||||
tx, err := s.driver.GetDB().Begin()
|
||||
if err != nil {
|
||||
|
|
@ -147,20 +128,23 @@ func (s *Store) preMigrate(ctx context.Context) error {
|
|||
return errors.Wrap(err, "failed to commit transaction")
|
||||
}
|
||||
|
||||
// TODO: using schema version in basic setting instead of migration history.
|
||||
if _, err := s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{
|
||||
Version: schemaVersion,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "failed to upsert migration history")
|
||||
// Upsert current schema version to database.
|
||||
schemaVersion, err := s.GetCurrentSchemaVersion()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get current schema version")
|
||||
}
|
||||
if err := s.updateCurrentSchemaVersion(ctx, schemaVersion); err != nil {
|
||||
return errors.Wrap(err, "failed to update current schema version")
|
||||
}
|
||||
}
|
||||
|
||||
if s.profile.Mode == "prod" {
|
||||
if err := s.normalizedMigrationHistoryList(ctx); err != nil {
|
||||
if err := s.normalizeMigrationHistoryList(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to normalize migration history list")
|
||||
}
|
||||
if err := s.migrateSchemaVersionToSetting(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to migrate schema version to setting")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -249,7 +233,22 @@ func (*Store) execute(ctx context.Context, tx *sql.Tx, stmt string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error {
|
||||
func (s *Store) updateCurrentSchemaVersion(ctx context.Context, schemaVersion string) error {
|
||||
workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get workspace basic setting")
|
||||
}
|
||||
workspaceBasicSetting.SchemaVersion = schemaVersion
|
||||
if _, err := s.UpsertWorkspaceSetting(ctx, &storepb.WorkspaceSetting{
|
||||
Key: storepb.WorkspaceSettingKey_BASIC,
|
||||
Value: &storepb.WorkspaceSetting_BasicSetting{BasicSetting: workspaceBasicSetting},
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "failed to upsert workspace setting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) normalizeMigrationHistoryList(ctx context.Context) error {
|
||||
migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to find migration history")
|
||||
|
|
@ -258,6 +257,9 @@ func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error {
|
|||
for _, migrationHistory := range migrationHistoryList {
|
||||
versions = append(versions, migrationHistory.Version)
|
||||
}
|
||||
if len(versions) == 0 {
|
||||
return errors.Errorf("no migration history found")
|
||||
}
|
||||
sort.Sort(version.SortVersion(versions))
|
||||
latestVersion := versions[len(versions)-1]
|
||||
latestMinorVersion := version.GetMinorVersion(latestVersion)
|
||||
|
|
@ -289,30 +291,37 @@ func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error {
|
|||
if version.IsVersionGreaterOrEqualThan(latestVersion, latestSchemaVersion) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start a transaction to insert the latest schema version to migration_history.
|
||||
tx, err := s.driver.GetDB().Begin()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to start transaction")
|
||||
if _, err := s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{
|
||||
Version: latestSchemaVersion,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "failed to upsert latest migration history")
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if err := s.execute(ctx, tx, fmt.Sprintf("INSERT INTO migration_history (version) VALUES ('%s')", latestSchemaVersion)); err != nil {
|
||||
return errors.Wrap(err, "failed to insert migration history")
|
||||
}
|
||||
return tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) updateCurrentSchemaVersion(ctx context.Context, schemaVersion string) error {
|
||||
func (s *Store) migrateSchemaVersionToSetting(ctx context.Context) error {
|
||||
migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to find migration history")
|
||||
}
|
||||
versions := []string{}
|
||||
for _, migrationHistory := range migrationHistoryList {
|
||||
versions = append(versions, migrationHistory.Version)
|
||||
}
|
||||
if len(versions) == 0 {
|
||||
return errors.Errorf("no migration history found")
|
||||
}
|
||||
sort.Sort(version.SortVersion(versions))
|
||||
latestVersion := versions[len(versions)-1]
|
||||
|
||||
workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get workspace basic setting")
|
||||
}
|
||||
workspaceBasicSetting.SchemaVersion = schemaVersion
|
||||
if _, err := s.UpsertWorkspaceSetting(ctx, &storepb.WorkspaceSetting{
|
||||
Key: storepb.WorkspaceSettingKey_BASIC,
|
||||
Value: &storepb.WorkspaceSetting_BasicSetting{BasicSetting: workspaceBasicSetting},
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "failed to upsert workspace setting")
|
||||
if version.IsVersionGreaterOrEqualThan(workspaceBasicSetting.SchemaVersion, latestVersion) {
|
||||
if err := s.updateCurrentSchemaVersion(ctx, latestVersion); err != nil {
|
||||
return errors.Wrap(err, "failed to update current schema version")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue