Welcome to mirror list, hosted at ThFree Co, Russian Federation.

postgres.go « glsql « datastore « praefect « internal - gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 3bb13621278992feafb97e006f7e01e75f5878ae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
// Package glsql (Gitaly SQL) is a helper package to work with plain SQL queries.
package glsql

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"github.com/jackc/pgconn"
	"github.com/jackc/pgtype"
	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/stdlib"
	migrate "github.com/rubenv/sql-migrate"
	"gitlab.com/gitlab-org/gitaly/v15/internal/praefect/config"
	"gitlab.com/gitlab-org/gitaly/v15/internal/praefect/datastore/migrations"
)

// OpenDB opens a connection pool to the database described by conf and verifies
// it with a ping before handing it to the caller.
//
// The non-direct DSN built from conf is parsed by pgx, the resulting connection
// configuration is registered with the pgx stdlib adapter and the pool is opened
// through database/sql using the "pgx" driver. If the ping fails the pool is
// closed and the ping error is returned wrapped; ctx bounds the ping only.
func OpenDB(ctx context.Context, conf config.DB) (*sql.DB, error) {
	pgxConf, err := pgx.ParseConfig(DSN(conf, false))
	if err != nil {
		return nil, err
	}

	pool, err := sql.Open("pgx", stdlib.RegisterConnConfig(pgxConf))
	if err != nil {
		return nil, err
	}

	pingErr := pool.PingContext(ctx)
	if pingErr == nil {
		return pool, nil
	}

	// The pool is unusable at this point; the ping failure is the error worth
	// reporting, so a failure to close is intentionally ignored.
	_ = pool.Close()
	return nil, fmt.Errorf("send ping: %w", pingErr)
}

// DSN compiles the configuration into a keyword/value data source name
// ("key=value" pairs separated by single spaces), which is the format OpenDB
// passes to pgx.ParseConfig.
//
// When direct is false the regular connection parameters of db are used as is.
// When direct is true every parameter is taken from db.SessionPooled if it is set
// there and falls back to the regular value otherwise; host and port additionally
// fall back to db.HostNoProxy and db.PortNoProxy before the regular value.
//
// The port is emitted only when it is greater than zero and every other parameter
// only when it is non-empty. prefer_simple_protocol=true is always appended.
// Backslashes, single quotes and spaces inside values are escaped with a backslash.
func DSN(db config.DB, direct bool) string {
	var hostVal, userVal, passwordVal, dbNameVal string
	var sslModeVal, sslCertVal, sslKeyVal, sslRootCertVal string
	var portVal int

	if direct {
		hostVal = coalesceStr(db.SessionPooled.Host, db.HostNoProxy, db.Host)
		portVal = coalesceInt(db.SessionPooled.Port, db.PortNoProxy, db.Port)
		userVal = coalesceStr(db.SessionPooled.User, db.User)
		passwordVal = coalesceStr(db.SessionPooled.Password, db.Password)
		dbNameVal = coalesceStr(db.SessionPooled.DBName, db.DBName)
		sslModeVal = coalesceStr(db.SessionPooled.SSLMode, db.SSLMode)
		sslCertVal = coalesceStr(db.SessionPooled.SSLCert, db.SSLCert)
		sslKeyVal = coalesceStr(db.SessionPooled.SSLKey, db.SSLKey)
		sslRootCertVal = coalesceStr(db.SessionPooled.SSLRootCert, db.SSLRootCert)
	} else {
		hostVal = db.Host
		portVal = db.Port
		userVal = db.User
		passwordVal = db.Password
		dbNameVal = db.DBName
		sslModeVal = db.SSLMode
		sslCertVal = db.SSLCert
		sslKeyVal = db.SSLKey
		sslRootCertVal = db.SSLRootCert
	}

	// The backslash is the escape character of the keyword/value format, so it has
	// to be escaped itself: otherwise a value ending in a backslash would escape the
	// separator that follows it and swallow the next key. A single-pass replacer is
	// used so that the backslashes inserted for quotes and spaces are not escaped
	// a second time.
	escaper := strings.NewReplacer(`\`, `\\`, "'", `\'`, " ", `\ `)

	var fields []string
	if portVal > 0 {
		fields = append(fields, fmt.Sprintf("port=%d", portVal))
	}

	for _, kv := range []struct{ key, value string }{
		{"host", hostVal},
		{"user", userVal},
		{"password", passwordVal},
		{"dbname", dbNameVal},
		{"sslmode", sslModeVal},
		{"sslcert", sslCertVal},
		{"sslkey", sslKeyVal},
		{"sslrootcert", sslRootCertVal},
		{"prefer_simple_protocol", "true"},
	} {
		// Unset parameters are left out so that the driver applies its own defaults.
		if len(kv.value) == 0 {
			continue
		}

		fields = append(fields, kv.key+"="+escaper.Replace(kv.value))
	}

	return strings.Join(fields, " ")
}

// Migrate applies every pending SQL migration known to the migrations package.
//
// Applied migrations are tracked in the table named by migrations.MigrationTableName
// and ignoreUnknown is forwarded to the migration set's IgnoreUnknown option.
// The count and error are returned exactly as reported by MigrationSet.Exec.
func Migrate(db *sql.DB, ignoreUnknown bool) (int, error) {
	source := &migrate.MemoryMigrationSource{Migrations: migrations.All()}

	set := migrate.MigrationSet{
		TableName:     migrations.MigrationTableName,
		IgnoreUnknown: ignoreUnknown,
	}

	return set.Exec(db, "postgres", source, migrate.Up)
}

// MigrateSome applies migration m together with all not yet applied migrations
// that precede it. To make sure only a single migration is executed per call, run
// sql-migrate.PlanMigration and invoke MigrateSome for each migration it returns.
func MigrateSome(m *migrate.Migration, db *sql.DB, ignoreUnknown bool) (int, error) {
	// sql-migrate.ToApply() expects every migration prior to the final one to be
	// present in the slice. Passing only the target migration would result in it
	// not being executed, hence the leading migrations are included as well.
	source := &migrate.MemoryMigrationSource{Migrations: leadingMigrations(m)}

	set := migrate.MigrationSet{
		TableName:     migrations.MigrationTableName,
		IgnoreUnknown: ignoreUnknown,
	}

	return set.Exec(db, "postgres", source, migrate.Up)
}

// leadingMigrations returns the prefix of migrations.All() that ends with the
// migration whose Id equals target.Id, i.e. all migrations up to and including
// the one to be applied.
func leadingMigrations(target *migrate.Migration) []*migrate.Migration {
	all := migrations.All()

	for idx := range all {
		if all[idx].Id != target.Id {
			continue
		}
		return all[:idx+1]
	}

	// The planned migration is not part of migrations.All(): assume it is more
	// recent than any known one and return the complete list.
	return all
}

// Querier is an abstraction over *sql.DB and *sql.Tx that allows callers to use their
// query methods without being aware of the actual type, so the same code can run
// either directly against the pool or inside a transaction.
type Querier interface {
	// QueryContext has the signature of the QueryContext method of *sql.DB and *sql.Tx.
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	// QueryRowContext has the signature of the QueryRowContext method of *sql.DB and *sql.Tx.
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
	// ExecContext has the signature of the ExecContext method of *sql.DB and *sql.Tx.
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}

// Notification represents a notification received from the database.
// It is the value handed to ListenHandler.Notification.
type Notification struct {
	// Channel is the name of the channel the notification was received on.
	Channel string
	// Payload is the payload of the notification.
	Payload string
}

// ListenHandler contains the set of callbacks invoked on the corresponding events
// of a notification listener.
type ListenHandler interface {
	// Notification is triggered each time a new notification is received.
	Notification(Notification)
	// Disconnect is triggered once the connection to the remote service is lost.
	// The passed in error is never nil and contains the cause of the disconnection.
	Disconnect(error)
	// Connected is triggered once the connection to the remote service is established.
	Connected()
}

// DestProvider supplies the list of pointers that row values are scanned into.
// ScanAll calls To once per row and passes the result to (*sql.Rows).Scan.
type DestProvider interface {
	// To returns the list of destination pointers for a single row.
	// It is not an idempotent operation: each call returns a new list.
	To() []interface{}
}

// ScanAll reads all remaining data from 'rows' into holders provided by 'in'.
//
// For every row it requests a fresh list of destinations via in.To() and scans
// the row into them. It returns the first scan error, or the error that stopped
// the iteration as reported by rows.Err(), or nil once all rows were consumed.
// ScanAll does not call rows.Close itself; that remains the caller's responsibility.
func ScanAll(rows *sql.Rows, in DestProvider) (err error) {
	for rows.Next() {
		if err = rows.Scan(in.To()...); err != nil {
			return err
		}
	}

	// rows.Next returns false both when the result set is exhausted and when the
	// iteration failed; without this check a failure would be reported as success
	// with only part of the rows read.
	return rows.Err()
}

// Uint64Provider is a DestProvider for result sets with a single uint64 column.
// Pass it to ScanAll to read all rows and then fetch the result with Values.
type Uint64Provider []*uint64

// Values returns the values collected so far, in the order their holders were
// handed out by To. It returns nil when no holder has been handed out.
func (p *Uint64Provider) Values() []uint64 {
	holders := *p
	if len(holders) == 0 {
		return nil
	}

	out := make([]uint64, 0, len(holders))
	for _, holder := range holders {
		out = append(out, *holder)
	}
	return out
}

// To allocates a new holder, remembers it and returns it as the only destination
// for a scan operation. Every call adds one more element to the provider.
func (p *Uint64Provider) To() []interface{} {
	holder := new(uint64)
	*p = append(*p, holder)
	return []interface{}{holder}
}

// StringProvider is a DestProvider for result sets with a single string column.
// Pass it to ScanAll to read all rows and then fetch the result with Values.
type StringProvider []*string

// Values returns the values collected so far, in the order their holders were
// handed out by To. It returns nil when no holder has been handed out.
func (p *StringProvider) Values() []string {
	holders := *p
	if len(holders) == 0 {
		return nil
	}

	out := make([]string, 0, len(holders))
	for _, holder := range holders {
		out = append(out, *holder)
	}
	return out
}

// To allocates a new holder, remembers it and returns it as the only destination
// for a scan operation. Every call adds one more element to the provider.
func (p *StringProvider) To() []interface{} {
	holder := new(string)
	*p = append(*p, holder)
	return []interface{}{holder}
}

// coalesceStr returns the first non-empty string of values,
// or an empty string when there is none.
func coalesceStr(values ...string) string {
	for i := range values {
		if len(values[i]) > 0 {
			return values[i]
		}
	}
	return ""
}

// coalesceInt returns the first non-zero integer of values,
// or zero when there is none.
func coalesceInt(values ...int) int {
	var first int
	for _, candidate := range values {
		if candidate == 0 {
			continue
		}
		first = candidate
		break
	}
	return first
}

// StringArray is a wrapper around pgtype.TextArray that provides helper methods.
type StringArray struct {
	pgtype.TextArray
}

// Slice converts StringArray into a slice of strings.
//
// It returns nil when the array itself is not present. Otherwise it returns a
// non-nil slice holding the elements in their original order; an element is
// considered a valid string only if it is present, so elements with any other
// status (such as null) are skipped and the result may be shorter than the array.
func (sa StringArray) Slice() []string {
	if sa.Status != pgtype.Present {
		return nil
	}

	// The array is known to be present from here on, no further status check of
	// the array itself is needed.
	res := make([]string, 0, len(sa.Elements))
	for _, e := range sa.Elements {
		if e.Status != pgtype.Present {
			continue
		}
		res = append(res, e.String)
	}
	return res
}

// errorCondition is a checker of an additional condition of an error.
// It receives the PostgreSQL error being inspected and reports whether the
// condition holds for it; isPgError requires all given conditions to hold.
type errorCondition func(*pgconn.PgError) bool

// withConstraintName builds an errorCondition that holds only for errors whose
// ConstraintName is exactly equal to name.
func withConstraintName(name string) errorCondition {
	var matchesConstraint errorCondition = func(pgErr *pgconn.PgError) bool {
		return name == pgErr.ConstraintName
	}
	return matchesConstraint
}

// IsQueryCancelled reports whether err is, or wraps, a PostgreSQL error that
// signals a query cancellation.
func IsQueryCancelled(err error) bool {
	// SQLSTATE of query_canceled, see
	// https://www.postgresql.org/docs/11/errcodes-appendix.html
	const queryCanceled = "57014"

	return isPgError(err, queryCanceled, nil)
}

// IsUniqueViolation returns true if an error is a unique violation.
func IsUniqueViolation(err error, constraint string) bool {
	// https://www.postgresql.org/docs/11/errcodes-appendix.html
	// unique_violation
	return isPgError(err, "23505", []errorCondition{withConstraintName(constraint)})
}

// isPgError reports whether err is, or wraps, a *pgconn.PgError with the given
// SQLSTATE code for which every one of the conditions holds. With no conditions
// the code match alone decides.
func isPgError(err error, code string, conditions []errorCondition) bool {
	var pgErr *pgconn.PgError
	if !errors.As(err, &pgErr) {
		return false
	}

	if pgErr.Code != code {
		return false
	}

	// Conditions are evaluated in order and the first one that fails decides.
	for i := range conditions {
		if holds := conditions[i](pgErr); !holds {
			return false
		}
	}
	return true
}