diff options
author | Patrick Steinhardt <psteinhardt@gitlab.com> | 2020-07-03 08:58:14 +0300 |
---|---|---|
committer | Patrick Steinhardt <psteinhardt@gitlab.com> | 2020-07-03 08:58:14 +0300 |
commit | de246fc4268b1db216a348f01f37f54907b21635 (patch) | |
tree | 900cf600313c3148a5eb4593ecb58bcdfe842c94 | |
parent | 78873d0bfe8d091d60accb813e17c9c5e8a014ec (diff) | |
parent | abbe26db3f3a84f7498919919e3c5561d39ff800 (diff) |
Merge branch 'pks-tx-weighted-voting' into 'master'
Implement weighted voting
See merge request gitlab-org/gitaly!2334
-rw-r--r-- | changelogs/unreleased/pks-tx-weighted-voting.yml | 5 | ||||
-rw-r--r-- | internal/praefect/coordinator.go | 7 | ||||
-rw-r--r-- | internal/praefect/transaction_test.go | 220 | ||||
-rw-r--r-- | internal/praefect/transactions/manager.go | 13 | ||||
-rw-r--r-- | internal/praefect/transactions/transaction.go | 85 |
5 files changed, 291 insertions, 39 deletions
diff --git a/changelogs/unreleased/pks-tx-weighted-voting.yml b/changelogs/unreleased/pks-tx-weighted-voting.yml new file mode 100644 index 000000000..8b4973e86 --- /dev/null +++ b/changelogs/unreleased/pks-tx-weighted-voting.yml @@ -0,0 +1,5 @@ +--- +title: Implement weighted voting +merge_request: 2334 +author: +type: added diff --git a/internal/praefect/coordinator.go b/internal/praefect/coordinator.go index 4a0748512..e94b9a55f 100644 --- a/internal/praefect/coordinator.go +++ b/internal/praefect/coordinator.go @@ -197,14 +197,17 @@ func (c *Coordinator) mutatorStreamParameters(ctx context.Context, call grpcCall if _, ok := transactionRPCs[call.fullMethodName]; ok && featureflag.IsEnabled(ctx, featureflag.ReferenceTransactions) { var voters []transactions.Voter + var threshold uint for _, node := range append(shard.Secondaries, shard.Primary) { voters = append(voters, transactions.Voter{ - Name: node.GetStorage(), + Name: node.GetStorage(), + Votes: 1, }) + threshold += 1 } - transactionID, transactionCleanup, err := c.txMgr.RegisterTransaction(ctx, voters) + transactionID, transactionCleanup, err := c.txMgr.RegisterTransaction(ctx, voters, threshold) if err != nil { return nil, fmt.Errorf("registering transactions: %w", err) } diff --git a/internal/praefect/transaction_test.go b/internal/praefect/transaction_test.go index b93cb6761..3d3b15b0d 100644 --- a/internal/praefect/transaction_test.go +++ b/internal/praefect/transaction_test.go @@ -3,6 +3,7 @@ package praefect import ( "context" "crypto/sha1" + "fmt" "sync" "testing" "time" @@ -70,8 +71,8 @@ func TestTransactionSucceeds(t *testing.T) { client := gitalypb.NewRefTransactionClient(cc) transactionID, cancelTransaction, err := txMgr.RegisterTransaction(ctx, []transactions.Voter{ - {Name: "node1"}, - }) + {Name: "node1", Votes: 1}, + }, 1) require.NoError(t, err) require.NotZero(t, transactionID) defer cancelTransaction() @@ -169,11 +170,13 @@ func TestTransactionWithMultipleNodes(t *testing.T) { for _, tc := range testcases { t.Run(tc.desc, func(t *testing.T) { var voters []transactions.Voter + var threshold uint for _, node := range tc.nodes { - voters = append(voters, transactions.Voter{Name: node}) + voters = append(voters, transactions.Voter{Name: node, Votes: 1}) + threshold += 1 } - transactionID, cancelTransaction, err := txMgr.RegisterTransaction(ctx, voters) + transactionID, cancelTransaction, err := txMgr.RegisterTransaction(ctx, voters, threshold) require.NoError(t, err) defer cancelTransaction() @@ -208,9 +211,9 @@ func TestTransactionWithContextCancellation(t *testing.T) { ctx, cancel := testhelper.Context() transactionID, cancelTransaction, err := txMgr.RegisterTransaction(ctx, []transactions.Voter{ - {Name: "voter"}, - {Name: "absent"}, - }) + {Name: "voter", Votes: 1}, + {Name: "absent", Votes: 1}, + }, 2) require.NoError(t, err) defer cancelTransaction() @@ -240,17 +243,206 @@ func TestTransactionRegistrationWithInvalidNodesFails(t *testing.T) { txMgr := transactions.NewManager() - _, _, err := txMgr.RegisterTransaction(ctx, []transactions.Voter{}) + _, _, err := txMgr.RegisterTransaction(ctx, []transactions.Voter{}, 1) require.Equal(t, transactions.ErrMissingNodes, err) _, _, err = txMgr.RegisterTransaction(ctx, []transactions.Voter{ - {Name: "node1"}, - {Name: "node2"}, - {Name: "node1"}, - }) + {Name: "node1", Votes: 1}, + {Name: "node2", Votes: 1}, + {Name: "node1", Votes: 1}, + }, 3) require.Equal(t, transactions.ErrDuplicateNodes, err) } +func TestTransactionRegistrationWithInvalidThresholdFails(t *testing.T) { + tc := []struct { + desc string + votes []uint + threshold uint + }{ + { + desc: "threshold is unreachable", + votes: []uint{1, 1}, + threshold: 3, + }, + { + desc: "threshold of zero fails", + votes: []uint{0}, + threshold: 0, + }, + { + desc: "threshold smaller than majority fails", + votes: []uint{1, 1, 1}, + threshold: 1, + }, + { + desc: "threshold equaling majority fails", + votes: []uint{1, 1, 1, 1}, + threshold: 2, + }, + { + desc: "threshold accounts for higher node votes", + votes: []uint{2, 2, 2, 2}, + threshold: 4, + }, + } + + ctx, cleanup := testhelper.Context() + defer cleanup() + + txMgr := transactions.NewManager() + + for _, tc := range tc { + t.Run(tc.desc, func(t *testing.T) { + var voters []transactions.Voter + + for i, votes := range tc.votes { + voters = append(voters, transactions.Voter{ + Name: fmt.Sprintf("node-%d", i), + Votes: votes, + }) + } + + _, _, err := txMgr.RegisterTransaction(ctx, voters, tc.threshold) + require.Equal(t, transactions.ErrInvalidThreshold, err) + }) + } +} + +func TestTransactionReachesQuorum(t *testing.T) { + type voter struct { + votes uint + vote string + showsUp bool + shouldSucceed bool + } + + tc := []struct { + desc string + voters []voter + threshold uint + }{ + { + desc: "quorum is is not reached without majority", + voters: []voter{ + {votes: 1, vote: "foo", showsUp: true, shouldSucceed: false}, + {votes: 1, vote: "bar", showsUp: true, shouldSucceed: false}, + {votes: 1, vote: "baz", showsUp: true, shouldSucceed: false}, + }, + threshold: 2, + }, + { + desc: "quorum is reached with unweighted node failing", + voters: []voter{ + {votes: 1, vote: "foo", showsUp: true, shouldSucceed: true}, + {votes: 0, vote: "bar", showsUp: true, shouldSucceed: false}, + }, + threshold: 1, + }, + { + desc: "quorum is reached with majority", + voters: []voter{ + {votes: 1, vote: "foo", showsUp: true, shouldSucceed: true}, + {votes: 1, vote: "foo", showsUp: true, shouldSucceed: true}, + {votes: 1, vote: "bar", showsUp: true, shouldSucceed: false}, + }, + threshold: 2, + }, + { + desc: "quorum is reached with high vote outweighing", + voters: []voter{ + {votes: 3, vote: "foo", showsUp: true, shouldSucceed: true}, + {votes: 1, vote: "bar", showsUp: true, shouldSucceed: false}, + {votes: 1, vote: "bar", showsUp: true, shouldSucceed: false}, + }, + threshold: 3, + }, + { + desc: "quorum is reached with high vote being outweighed", + voters: []voter{ + {votes: 3, vote: "foo", showsUp: true, shouldSucceed: false}, + {votes: 1, vote: "bar", showsUp: true, shouldSucceed: true}, + {votes: 1, vote: "bar", showsUp: true, shouldSucceed: true}, + {votes: 1, vote: "bar", showsUp: true, shouldSucceed: true}, + {votes: 1, vote: "bar", showsUp: true, shouldSucceed: true}, + }, + threshold: 4, + }, + { + desc: "quorum is reached with disappearing unweighted voter", + voters: []voter{ + {votes: 1, vote: "foo", showsUp: true, shouldSucceed: true}, + {votes: 0, vote: "foo", showsUp: false, shouldSucceed: false}, + }, + threshold: 1, + }, + { + desc: "quorum is reached with disappearing weighted voter", + voters: []voter{ + {votes: 1, vote: "foo", showsUp: true, shouldSucceed: true}, + {votes: 1, vote: "foo", showsUp: true, shouldSucceed: true}, + {votes: 1, vote: "bar", showsUp: false, shouldSucceed: false}, + }, + threshold: 2, + }, + } + + cc, txMgr, cleanup := runPraefectServerAndTxMgr(t) + defer cleanup() + + ctx, cleanup := testhelper.Context() + defer cleanup() + + client := gitalypb.NewRefTransactionClient(cc) + + for _, tc := range tc { + t.Run(tc.desc, func(t *testing.T) { + var voters []transactions.Voter + + for i, voter := range tc.voters { + voters = append(voters, transactions.Voter{ + Name: fmt.Sprintf("node-%d", i), + Votes: voter.votes, + }) + } + + transactionID, cancel, err := txMgr.RegisterTransaction(ctx, voters, tc.threshold) + require.NoError(t, err) + defer cancel() + + var wg sync.WaitGroup + for i, v := range tc.voters { + if !v.showsUp { + continue + } + + wg.Add(1) + go func(i int, v voter) { + defer wg.Done() + + name := fmt.Sprintf("node-%d", i) + hash := sha1.Sum([]byte(v.vote)) + + response, err := client.VoteTransaction(ctx, &gitalypb.VoteTransactionRequest{ + TransactionId: transactionID, + Node: name, + ReferenceUpdatesHash: hash[:], + }) + require.NoError(t, err) + + if v.shouldSucceed { + require.Equal(t, gitalypb.VoteTransactionResponse_COMMIT, response.State, "node should have received COMMIT") + } else { + require.Equal(t, gitalypb.VoteTransactionResponse_ABORT, response.State, "node should have received ABORT") + } + }(i, v) + } + + wg.Wait() + }) + } +} + func TestTransactionFailures(t *testing.T) { counter, opts := setupMetrics() cc, _, cleanup := runPraefectServerAndTxMgr(t, opts...) @@ -287,8 +479,8 @@ func TestTransactionCancellation(t *testing.T) { client := gitalypb.NewRefTransactionClient(cc) transactionID, cancelTransaction, err := txMgr.RegisterTransaction(ctx, []transactions.Voter{ - {Name: "node1"}, - }) + {Name: "node1", Votes: 1}, + }, 1) require.NoError(t, err) require.NotZero(t, transactionID) diff --git a/internal/praefect/transactions/manager.go b/internal/praefect/transactions/manager.go index 721e47522..25248caac 100644 --- a/internal/praefect/transactions/manager.go +++ b/internal/praefect/transactions/manager.go @@ -106,8 +106,10 @@ func (mgr *Manager) log(ctx context.Context) logrus.FieldLogger { type CancelFunc func() // RegisterTransaction registers a new reference transaction for a set of nodes -// taking part in the transaction. -func (mgr *Manager) RegisterTransaction(ctx context.Context, voters []Voter) (uint64, CancelFunc, error) { +// taking part in the transaction. `threshold` is the threshold at which an +// election will succeed. It needs to be in the range `weight(voters)/2 < +// threshold <= weight(voters) to avoid indecidable votes. +func (mgr *Manager) RegisterTransaction(ctx context.Context, voters []Voter, threshold uint) (uint64, CancelFunc, error) { mgr.lock.Lock() defer mgr.lock.Unlock() @@ -117,7 +119,7 @@ func (mgr *Manager) RegisterTransaction(ctx context.Context, voters []Voter) (ui // nodes still have in-flight transactions. transactionID := mgr.txIDGenerator.ID() - transaction, err := newTransaction(voters) + transaction, err := newTransaction(voters, threshold) if err != nil { return 0, nil, err } @@ -159,7 +161,7 @@ func (mgr *Manager) voteTransaction(ctx context.Context, transactionID uint64, n return err } - if err := transaction.collectVotes(ctx); err != nil { + if err := transaction.collectVotes(ctx, node); err != nil { return err } @@ -167,8 +169,7 @@ func (mgr *Manager) voteTransaction(ctx context.Context, transactionID uint64, n } // VoteTransaction is called by a client who's casting a vote on a reference -// transaction. It will wait for all clients of a given transaction to start -// the transaction and perform a vote. +// transaction. It waits until quorum was reached on the given transaction. func (mgr *Manager) VoteTransaction(ctx context.Context, transactionID uint64, node string, hash []byte) error { start := time.Now() defer func() { diff --git a/internal/praefect/transactions/transaction.go b/internal/praefect/transactions/transaction.go index a5c09ce95..40bb809c4 100644 --- a/internal/praefect/transactions/transaction.go +++ b/internal/praefect/transactions/transaction.go @@ -11,6 +11,7 @@ import ( var ( ErrDuplicateNodes = errors.New("transactions cannot have duplicate nodes") ErrMissingNodes = errors.New("transaction requires at least one node") + ErrInvalidThreshold = errors.New("transaction has invalid threshold") ErrTransactionVoteFailed = errors.New("transaction vote failed") ErrTransactionCanceled = errors.New("transaction was canceled") ) @@ -19,6 +20,10 @@ var ( type Voter struct { // Name of the voter, usually Gitaly's storage name. Name string + // Votes is the number of votes available to this voter in the voting + // process. `0` means the outcome of the vote will not be influenced by + // this voter. + Votes uint vote vote } @@ -44,16 +49,21 @@ type transaction struct { doneCh chan interface{} cancelCh chan interface{} + threshold uint + lock sync.Mutex votersByNode map[string]*Voter + voteCounts map[vote]uint } -func newTransaction(voters []Voter) (*transaction, error) { +func newTransaction(voters []Voter, threshold uint) (*transaction, error) { if len(voters) == 0 { return nil, ErrMissingNodes } + var totalVotes uint votersByNode := make(map[string]*Voter, len(voters)) + for _, voter := range voters { if _, ok := votersByNode[voter.Name]; ok { return nil, ErrDuplicateNodes @@ -61,12 +71,28 @@ func newTransaction(voters []Voter) (*transaction, error) { voter := voter // rescope loop variable votersByNode[voter.Name] = &voter + totalVotes += voter.Votes + } + + // If the given threshold is smaller than the total votes, then we + // cannot ever reach quorum. + if totalVotes < threshold { + return nil, ErrInvalidThreshold + } + + // If the threshold is less or equal than half of all node's votes, + // it's possible to reach multiple different quorums that settle on + // different outcomes. + if threshold*2 <= totalVotes { + return nil, ErrInvalidThreshold } return &transaction{ doneCh: make(chan interface{}), cancelCh: make(chan interface{}), + threshold: threshold, votersByNode: votersByNode, + voteCounts: make(map[vote]uint, len(votersByNode)), }, nil } @@ -94,23 +120,45 @@ func (t *transaction) vote(node string, hash []byte) error { } voter.vote = vote - // Count votes to see if we're done. If there are no more votes, then - // we must notify other voters (and ourselves) by closing the `done` - // channel. + oldCount := t.voteCounts[vote] + newCount := oldCount + voter.Votes + t.voteCounts[vote] = newCount + + // If the threshold was reached before already, we mustn't try to + // signal the other voters again. + if oldCount >= t.threshold { + return nil + } + + // If we've just crossed the threshold, signal all voters that the + // voting has concluded. + if newCount >= t.threshold { + close(t.doneCh) + return nil + } + + // If any other vote has already reached the threshold, we mustn't try + // to notify voters again. + for _, count := range t.voteCounts { + if count >= t.threshold { + return nil + } + } + + // If any of the voters didn't yet cast its vote, we need to wait for + // them. for _, voter := range t.votersByNode { if voter.vote.isEmpty() { return nil } } - // As only the last voter may see that all participants have cast their - // vote, this can really only be called by a single goroutine. + // Otherwise, signal voters that all votes were gathered. close(t.doneCh) - return nil } -func (t *transaction) collectVotes(ctx context.Context) error { +func (t *transaction) collectVotes(ctx context.Context, node string) error { select { case <-ctx.Done(): return ctx.Err() @@ -120,15 +168,18 @@ func (t *transaction) collectVotes(ctx context.Context) error { break } - // Count votes to see whether we reached agreement or not. There should - // be no need to lock as nobody will modify the votes anymore. - var firstVote vote - for _, voter := range t.votersByNode { - if firstVote.isEmpty() { - firstVote = voter.vote - } else if firstVote != voter.vote { - return ErrTransactionVoteFailed - } + t.lock.Lock() + defer t.lock.Unlock() + + voter, ok := t.votersByNode[node] + if !ok { + return fmt.Errorf("invalid node for transaction: %q", node) + } + + // See if our vote crossed the threshold. As there can be only one vote + // exceeding it, we know we're the winner in that case. + if t.voteCounts[voter.vote] < t.threshold { + return ErrTransactionVoteFailed } return nil |