diff --git a/store/test/containers.go b/store/test/containers.go index f436782a5..cdb65a1bc 100644 --- a/store/test/containers.go +++ b/store/test/containers.go @@ -42,17 +42,17 @@ var ( wait.ForListeningPort("5230/tcp"), ).WithDeadline(180 * time.Second) - mysqlContainer *mysql.MySQLContainer - postgresContainer *postgres.PostgresContainer + mysqlContainer atomic.Pointer[mysql.MySQLContainer] + postgresContainer atomic.Pointer[postgres.PostgresContainer] mysqlOnce sync.Once postgresOnce sync.Once - mysqlBaseDSN string - postgresBaseDSN string + mysqlBaseDSN atomic.Value // stores string + postgresBaseDSN atomic.Value // stores string dbCounter atomic.Int64 dbCreationMutex sync.Mutex // Protects database creation operations // Network for container communication. - testDockerNetwork *testcontainers.DockerNetwork + testDockerNetwork atomic.Pointer[testcontainers.DockerNetwork] testNetworkOnce sync.Once ) @@ -65,9 +65,9 @@ func getTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error) networkErr = err return } - testDockerNetwork = nw + testDockerNetwork.Store(nw) }) - return testDockerNetwork, networkErr + return testDockerNetwork.Load(), networkErr } // GetMySQLDSN starts a MySQL container (if not already running) and creates a fresh database for this test. @@ -99,7 +99,7 @@ func GetMySQLDSN(t *testing.T) string { if err != nil { t.Fatalf("failed to start MySQL container: %v", err) } - mysqlContainer = container + mysqlContainer.Store(container) dsn, err := container.ConnectionString(ctx, "multiStatements=true") if err != nil { @@ -110,10 +110,11 @@ func GetMySQLDSN(t *testing.T) string { t.Fatalf("MySQL not ready for connections: %v", err) } - mysqlBaseDSN = dsn + mysqlBaseDSN.Store(dsn) }) - if mysqlBaseDSN == "" { + dsn, ok := mysqlBaseDSN.Load().(string) + if !ok || dsn == "" { t.Fatal("MySQL container failed to start in a previous test") } @@ -123,7 +124,7 @@ func GetMySQLDSN(t *testing.T) string { // Create a fresh database for this test dbName := fmt.Sprintf("memos_test_%d", dbCounter.Add(1)) - db, err := sql.Open("mysql", mysqlBaseDSN) + db, err := sql.Open("mysql", dsn) if err != nil { t.Fatalf("failed to connect to MySQL: %v", err) } @@ -134,7 +135,7 @@ func GetMySQLDSN(t *testing.T) string { } // Return DSN pointing to the new database - return strings.Replace(mysqlBaseDSN, "/init_db?", "/"+dbName+"?", 1) + return strings.Replace(dsn, "/init_db?", "/"+dbName+"?", 1) } // waitForDB polls the database until it's ready or timeout is reached. @@ -195,7 +196,7 @@ func GetPostgresDSN(t *testing.T) string { if err != nil { t.Fatalf("failed to start PostgreSQL container: %v", err) } - postgresContainer = container + postgresContainer.Store(container) dsn, err := container.ConnectionString(ctx, "sslmode=disable") if err != nil { @@ -206,10 +207,11 @@ func GetPostgresDSN(t *testing.T) string { t.Fatalf("PostgreSQL not ready for connections: %v", err) } - postgresBaseDSN = dsn + postgresBaseDSN.Store(dsn) }) - if postgresBaseDSN == "" { + dsn, ok := postgresBaseDSN.Load().(string) + if !ok || dsn == "" { t.Fatal("PostgreSQL container failed to start in a previous test") } @@ -219,7 +221,7 @@ func GetPostgresDSN(t *testing.T) string { // Create a fresh database for this test dbName := fmt.Sprintf("memos_test_%d", dbCounter.Add(1)) - db, err := sql.Open("postgres", postgresBaseDSN) + db, err := sql.Open("postgres", dsn) if err != nil { t.Fatalf("failed to connect to PostgreSQL: %v", err) } @@ -230,7 +232,7 @@ func GetPostgresDSN(t *testing.T) string { } // Return DSN pointing to the new database - return strings.Replace(postgresBaseDSN, "/init_db?", "/"+dbName+"?", 1) + return strings.Replace(dsn, "/init_db?", "/"+dbName+"?", 1) } // GetDedicatedMySQLDSN starts a dedicated MySQL container for migration testing. @@ -336,33 +338,35 @@ func GetDedicatedPostgresDSN(t *testing.T) (dsn string, containerHost string, cl // This is typically called from TestMain. func TerminateContainers() { ctx := context.Background() - if mysqlContainer != nil { - _ = mysqlContainer.Terminate(ctx) + if container := mysqlContainer.Load(); container != nil { + _ = container.Terminate(ctx) } - if postgresContainer != nil { - _ = postgresContainer.Terminate(ctx) + if container := postgresContainer.Load(); container != nil { + _ = container.Terminate(ctx) } - if testDockerNetwork != nil { - _ = testDockerNetwork.Remove(ctx) + if network := testDockerNetwork.Load(); network != nil { + _ = network.Remove(ctx) } } // GetMySQLContainerHost returns the MySQL container hostname for use within the Docker network. func GetMySQLContainerHost() string { - if mysqlContainer == nil { + container := mysqlContainer.Load() + if container == nil { return "" } - name, _ := mysqlContainer.Name(context.Background()) + name, _ := container.Name(context.Background()) // Remove leading slash from container name return strings.TrimPrefix(name, "/") } // GetPostgresContainerHost returns the PostgreSQL container hostname for use within the Docker network. func GetPostgresContainerHost() string { - if postgresContainer == nil { + container := postgresContainer.Load() + if container == nil { return "" } - name, _ := postgresContainer.Name(context.Background()) + name, _ := container.Name(context.Background()) return strings.TrimPrefix(name, "/") } @@ -395,11 +399,11 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon case "mysql": env["MEMOS_DRIVER"] = "mysql" env["MEMOS_DSN"] = cfg.DSN - opts = append(opts, network.WithNetwork(nil, testDockerNetwork)) + opts = append(opts, network.WithNetwork(nil, testDockerNetwork.Load())) case "postgres": env["MEMOS_DRIVER"] = "postgres" env["MEMOS_DSN"] = cfg.DSN - opts = append(opts, network.WithNetwork(nil, testDockerNetwork)) + opts = append(opts, network.WithNetwork(nil, testDockerNetwork.Load())) default: return nil, errors.Errorf("unsupported driver: %s", cfg.Driver) } diff --git a/store/test/main_test.go b/store/test/main_test.go index 5c4e19274..6890d501f 100644 --- a/store/test/main_test.go +++ b/store/test/main_test.go @@ -13,8 +13,7 @@ func TestMain(m *testing.M) { // If DRIVER is set, run tests for that driver only if os.Getenv("DRIVER") != "" { defer TerminateContainers() - m.Run() - return + os.Exit(m.Run()) } // No DRIVER set - run tests for all drivers sequentially