diff options
author | Patrick Steinhardt <psteinhardt@gitlab.com> | 2021-10-11 10:49:31 +0300 |
---|---|---|
committer | Patrick Steinhardt <psteinhardt@gitlab.com> | 2021-10-11 12:22:11 +0300 |
commit | 0f290a3d34314cb575d50c1e6c7361ce804a654a (patch) | |
tree | a9c99fbc7049d6353d3a6d0d9599f6bb4b6e6b1f | |
parent | 234423dea2ebafabb1924f36eaf841439e2fa01d (diff) |
migrations: Fix data race due to accessing global state
We have recently hit a data race in our migration tests due to two
different Goroutines accessing the global `IgnoreUnknown` flag in the
sql-migrations package. As it turns out, sql-migrations provides two
different APIs: a "simple" API which uses a global `MigrationSet` struct
to keep data in. And an explicit API where the caller needs to create
the `MigrationSet` by himself. Given that we use the former, it's clear
that we easily run into races as soon as two concurrent tests want to do
migrations.
Fix the data race by converting all callsites to use an explicit
`MigrationSet` structure without any global state.
-rw-r--r-- | internal/praefect/datastore/glsql/postgres.go | 13 | ||||
-rw-r--r-- | internal/praefect/datastore/migrations/migrations.go | 7 | ||||
-rw-r--r-- | internal/praefect/datastore/postgres.go | 18 |
3 files changed, 27 insertions, 11 deletions
diff --git a/internal/praefect/datastore/glsql/postgres.go b/internal/praefect/datastore/glsql/postgres.go index f9f3eb951..4b5efcd86 100644 --- a/internal/praefect/datastore/glsql/postgres.go +++ b/internal/praefect/datastore/glsql/postgres.go @@ -29,9 +29,16 @@ func OpenDB(conf config.DB) (*sql.DB, error) { // Migrate will apply all pending SQL migrations. func Migrate(db *sql.DB, ignoreUnknown bool) (int, error) { - migrationSource := &migrate.MemoryMigrationSource{Migrations: migrations.All()} - migrate.SetIgnoreUnknown(ignoreUnknown) - return migrate.Exec(db, "postgres", migrationSource, migrate.Up) + migrationSet := migrate.MigrationSet{ + IgnoreUnknown: ignoreUnknown, + TableName: migrations.MigrationTableName, + } + + migrationSource := &migrate.MemoryMigrationSource{ + Migrations: migrations.All(), + } + + return migrationSet.Exec(db, "postgres", migrationSource, migrate.Up) } // Querier is an abstraction on *sql.DB and *sql.Tx that allows to use their methods without awareness about actual type. diff --git a/internal/praefect/datastore/migrations/migrations.go b/internal/praefect/datastore/migrations/migrations.go index 1d5c0c0a2..3b0ff2704 100644 --- a/internal/praefect/datastore/migrations/migrations.go +++ b/internal/praefect/datastore/migrations/migrations.go @@ -4,14 +4,11 @@ import ( migrate "github.com/rubenv/sql-migrate" ) -const migrationTableName = "schema_migrations" +// MigrationTableName is the name of the SQL table used to store migration info. +const MigrationTableName = "schema_migrations" var allMigrations []*migrate.Migration -func init() { - migrate.SetTable(migrationTableName) -} - // All returns all migrations defined in the package func All() []*migrate.Migration { return allMigrations diff --git a/internal/praefect/datastore/postgres.go b/internal/praefect/datastore/postgres.go index 608abc9db..5d429ab97 100644 --- a/internal/praefect/datastore/postgres.go +++ b/internal/praefect/datastore/postgres.go @@ -50,7 +50,11 @@ func MigrateDownPlan(conf config.Config, max int) ([]string, error) { } defer db.Close() - planned, _, err := migrate.PlanMigration(db, sqlMigrateDialect, migrationSource(), migrate.Down, max) + migrationSet := migrate.MigrationSet{ + TableName: migrations.MigrationTableName, + } + + planned, _, err := migrationSet.PlanMigration(db, sqlMigrateDialect, migrationSource(), migrate.Down, max) if err != nil { return nil, err } @@ -71,7 +75,11 @@ func MigrateDown(conf config.Config, max int) (int, error) { } defer db.Close() - return migrate.ExecMax(db, sqlMigrateDialect, migrationSource(), migrate.Down, max) + migrationSet := migrate.MigrationSet{ + TableName: migrations.MigrationTableName, + } + + return migrationSet.ExecMax(db, sqlMigrateDialect, migrationSource(), migrate.Down, max) } // MigrateStatus returns the status of database migrations. The key of the map @@ -83,12 +91,16 @@ func MigrateStatus(conf config.Config) (map[string]*MigrationStatusRow, error) { } defer db.Close() + migrationSet := migrate.MigrationSet{ + TableName: migrations.MigrationTableName, + } + migrations, err := migrationSource().FindMigrations() if err != nil { return nil, err } - records, err := migrate.GetMigrationRecords(db, sqlMigrateDialect) + records, err := migrationSet.GetMigrationRecords(db, sqlMigrateDialect) if err != nil { return nil, err } |