Skip to content

Commit

Permalink
Add RAFT tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mprimi committed Oct 4, 2023
1 parent 4e414f1 commit 76267e5
Show file tree
Hide file tree
Showing 2 changed files with 682 additions and 0 deletions.
279 changes: 279 additions & 0 deletions server/raft_helpers_test.go
Expand Up @@ -17,8 +17,12 @@
package server

import (
"encoding"
"encoding/binary"
"encoding/json"
"fmt"
"hash"
"hash/crc32"
"math/rand"
"sync"
"testing"
Expand Down Expand Up @@ -274,3 +278,278 @@ func (rg smGroup) waitOnTotal(t *testing.T, expected int64) {
// newStateAdder is the factory for the stateAdder state machine: it wires the
// given server, RAFT configuration and RAFT node into a fresh instance.
func newStateAdder(s *Server, cfg *RaftConfig, n RaftNode) stateMachine {
	adder := &stateAdder{
		s:   s,
		n:   n,
		cfg: cfg,
	}
	return adder
}

// RaftChainOptions holds the package-level settings of the RAFT chain state
// machine test helper.
//
// verbose enables the debug output produced by logDebug.
// maxBlockSize is presumably the upper bound (in bytes) on the random payload
// of a proposed block — TODO confirm against its users.
var RaftChainOptions = struct {
	verbose      bool
	maxBlockSize int
}{
	verbose:      false,
	maxBlockSize: 25,
}

// raftChainStateMachine is a simple implementation of a replicated state
// machine on top of RAFT.
// Each value delivered is hashed on top of the existing running hash.
// All replicas should go through the same sequence of block hashes.
type raftChainStateMachine struct {
// Held by propose, applyEntry, stop, restart, getCurrentHash and snapshot.
sync.Mutex
// Server this replica runs on; used for its name and to (re)start the RAFT node.
s *Server
// RAFT node backing this replica; replaced by restart.
n RaftNode
// RAFT configuration; its Log is replaced with a new file store by restart.
cfg *RaftConfig
// Last leadership status passed to leaderChange.
leader bool
// Number of blocks this replica has proposed so far.
proposalSequence uint64
// Source of randomness for block sizes and block contents.
rng *rand.Rand
// Running hash over the data of every applied block.
hash hash.Hash
// Number of blocks applied; overwritten when a snapshot is loaded.
blocksApplied uint64
// Number of blocks applied since the last snapshot was taken or loaded.
blocksAppliedSinceSnapshot uint64
}

// ChainBlock is just a random array of bytes, but contains a little extra
// metadata to track its source.
type ChainBlock struct {
// Name of the server that proposed the block.
Proposer string
// Value of the proposer's proposal counter when the block was created.
ProposerSequence uint64
// Random payload; these are the bytes fed into the running hash.
Data []byte
}

// logDebug prints a formatted debug line to stdout, prefixed with the server
// name and the RAFT node ID. It is a no-op unless RaftChainOptions.verbose is
// set.
//
// The prefix is passed to Printf as arguments instead of being concatenated
// into the format string, so a '%' in the server name or node ID cannot be
// misinterpreted as a formatting verb.
func (sm *raftChainStateMachine) logDebug(format string, args ...any) {
	if !RaftChainOptions.verbose {
		return
	}
	fmt.Printf("[%s (%s)] %s\n", sm.s.Name(), sm.n.ID(), fmt.Sprintf(format, args...))
}

// server returns the server this state machine runs on.
func (sm *raftChainStateMachine) server() *Server {
return sm.s
}

// node returns the RAFT node currently backing this state machine.
// The lock is not taken here, while restart replaces sm.n under the lock.
func (sm *raftChainStateMachine) node() RaftNode {
return sm.n
}

// propose hands the serialized block to the RAFT node via ForwardProposal
// while holding the state machine lock. A failure is only reported through
// the debug log; nothing is returned to the caller.
func (sm *raftChainStateMachine) propose(data []byte) {
	sm.Lock()
	defer sm.Unlock()

	if proposeErr := sm.n.ForwardProposal(data); proposeErr != nil {
		sm.logDebug("block proposal error: %s", proposeErr)
	}
}

// applyEntry applies every entry of a committed entry to the chain state and
// then reports the committed index back to the RAFT node as applied.
// Normal entries are treated as blocks, snapshot entries replace the state.
// A nil committed entry is ignored; an unknown entry type panics.
func (sm *raftChainStateMachine) applyEntry(ce *CommittedEntry) {
	sm.Lock()
	defer sm.Unlock()

	// Nothing to apply
	if ce == nil {
		return
	}

	sm.logDebug("Apply entries #%d (%d entries)", ce.Index, len(ce.Entries))
	for _, e := range ce.Entries {
		switch e.Type {
		case EntryNormal:
			sm.applyBlock(e.Data)
		case EntrySnapshot:
			sm.loadSnapshot(e.Data)
		default:
			panic(fmt.Sprintf("[%s] unknown entry type: %s", sm.s.Name(), e.Type))
		}
	}
	sm.n.Applied(ce.Index)
}

// leaderChange logs the leadership transition (relative to the previously
// recorded status) and then records the new status.
func (sm *raftChainStateMachine) leaderChange(isLeader bool) {
	wasLeader := sm.leader

	switch {
	case wasLeader && !isLeader:
		sm.logDebug("Leader change: no longer leader")
	case wasLeader && isLeader:
		sm.logDebug("Elected leader while already leader")
	case !wasLeader && isLeader:
		sm.logDebug("Leader change: i am leader")
	default:
		// Was not leader and still is not
		sm.logDebug("Leader change")
	}

	sm.leader = isLeader
}

// stop stops the underlying RAFT node and clears the in-memory chain state
// (block counters and running hash) while holding the state machine lock.
func (sm *raftChainStateMachine) stop() {
	sm.Lock()
	defer sm.Unlock()
	sm.n.Stop()

	// Clear state, on restart it will be recovered from snapshot or peers
	sm.blocksApplied = 0
	// The since-snapshot counter is part of that state: if it stayed non-zero,
	// snapshot() would pass its "no new entries" guard and try to install a
	// snapshot of the freshly reset (empty) hash.
	sm.blocksAppliedSinceSnapshot = 0
	sm.hash.Reset()
	sm.logDebug("Stopped")
}

// restart brings a stopped replica back: it recreates the file store used as
// the RAFT log, starts a new RAFT node on the same server with the same
// configuration, and launches a new driver goroutine.
// It does nothing (beyond logging) if the current node is not in the Closed
// state. Any failure while recreating the store or the node panics.
func (sm *raftChainStateMachine) restart() {
sm.Lock()
defer sm.Unlock()

sm.logDebug("Restarting")

// Only a closed node can be restarted.
if sm.n.State() != Closed {
return
}

// The filestore is stopped as well, so need to extract the parts to recreate it.
// These assertions panic if the node is not a *raft or its WAL is not a *fileStore.
rn := sm.n.(*raft)
fs := rn.wal.(*fileStore)

var err error
// New store built from the old store's file store config and stream config.
sm.cfg.Log, err = newFileStore(fs.fcfg, fs.cfg.StreamConfig)
if err != nil {
panic(err)
}
// Replace the closed node with a newly started one using the updated config.
sm.n, err = sm.s.startRaftNode(globalAccountName, sm.cfg, pprofLabels{})
if err != nil {
panic(err)
}
// Finally restart the driver.
go smLoop(sm)
}

// proposeBlock creates a new block of random bytes, tagged with this server's
// name and its proposal sequence number, serializes it as JSON and proposes it
// through propose.
//
// The payload is between 1 and RaftChainOptions.maxBlockSize bytes long
// (previously the upper bound was hard-coded and the option was ignored).
// A serialization failure panics.
func (sm *raftChainStateMachine) proposeBlock() {
	// Track how many blocks this replica proposed
	sm.proposalSequence++

	// Upper bound on the payload size; clamp to 1 because Intn panics on
	// a non-positive argument.
	maxSize := RaftChainOptions.maxBlockSize
	if maxSize < 1 {
		maxSize = 1
	}

	// Create a block
	block := ChainBlock{
		Proposer:         sm.s.Name(),
		ProposerSequence: sm.proposalSequence,
		Data:             make([]byte, sm.rng.Intn(maxSize)+1),
	}
	// Data is random bytes
	sm.rng.Read(block.Data)

	// Serialize as JSON
	blockData, err := json.Marshal(block)
	if err != nil {
		panic(fmt.Sprintf("serialization error: %s", err))
	}
	sm.logDebug(
		"Proposing block <%s, %d, [%dB]>",
		block.Proposer,
		block.ProposerSequence,
		len(block.Data),
	)

	// Propose (may silently fail if this replica is not leader, or other reasons)
	sm.propose(blockData)
}

// applyBlock decodes a JSON-encoded ChainBlock, feeds its data into the
// running hash and bumps both block counters.
// A decoding failure, a short hash write or a hash write error panics.
func (sm *raftChainStateMachine) applyBlock(data []byte) {
	// Blocks travel through RAFT in JSON format
	var blk ChainBlock
	if err := json.Unmarshal(data, &blk); err != nil {
		panic(fmt.Sprintf("deserialization error: %s", err))
	}
	sm.logDebug("Applying block <%s, %d>", blk.Proposer, blk.ProposerSequence)

	// Extend the running hash with this block's payload
	written, err := sm.hash.Write(blk.Data)
	switch {
	case written != len(blk.Data):
		panic(fmt.Sprintf("unexpected checksum written %d data block size: %d", written, len(blk.Data)))
	case err != nil:
		panic(fmt.Sprintf("checksum error: %s", err))
	}

	// One more block applied, overall and since the last snapshot
	sm.blocksApplied++
	sm.blocksAppliedSinceSnapshot++

	sm.logDebug("Hash after %d blocks: %X ", sm.blocksApplied, sm.hash.Sum(nil))
}

// getCurrentHash reports, under the state machine lock, how many blocks have
// been applied and the current running hash formatted as upper-case hex.
func (sm *raftChainStateMachine) getCurrentHash() (uint64, string) {
	sm.Lock()
	defer sm.Unlock()

	digest := sm.hash.Sum(nil)
	return sm.blocksApplied, fmt.Sprintf("%X", digest)
}

// chainHashSnapshot is the JSON-serialized form of the state machine state
// that is installed as a RAFT snapshot and restored by loadSnapshot.
type chainHashSnapshot struct {
// Human-readable "<server name> (<node id>)" of the replica that took the snapshot.
SourceNode string
// Binary-marshaled internal state of the running hash.
HashData []byte
// Number of blocks applied when the snapshot was taken.
BlocksCount uint64
}

// snapshot captures the running hash state and the applied block count,
// encodes them as JSON and installs the result as a snapshot on the RAFT node.
// It is skipped when no block was applied since the previous snapshot.
// Any marshaling or installation failure panics.
func (sm *raftChainStateMachine) snapshot() {
	sm.Lock()
	defer sm.Unlock()

	if sm.blocksAppliedSinceSnapshot == 0 {
		sm.logDebug("Skip snapshot, no new entries")
		return
	}

	sm.logDebug("Snapshot (with %d blocks applied)", sm.blocksApplied)

	// The hash implementation must be able to export its internal state
	marshaler := sm.hash.(encoding.BinaryMarshaler)
	hashState, err := marshaler.MarshalBinary()
	if err != nil {
		panic(fmt.Sprintf("failed to marshal hash: %s", err))
	}

	// Build the snapshot and encode it as JSON in one go
	encoded, err := json.Marshal(chainHashSnapshot{
		SourceNode:  fmt.Sprintf("%s (%s)", sm.s.Name(), sm.n.ID()),
		HashData:    hashState,
		BlocksCount: sm.blocksApplied,
	})
	if err != nil {
		panic(fmt.Sprintf("failed to marshal snapshot: %s", err))
	}

	// Hand the raw bytes over to the RAFT node
	if err := sm.n.InstallSnapshot(encoded); err != nil {
		panic(fmt.Sprintf("failed to snapshot: %s", err))
	}

	// Start counting afresh from this snapshot
	sm.blocksAppliedSinceSnapshot = 0
}

// loadSnapshot replaces the chain state with the one stored in a JSON-encoded
// chainHashSnapshot: the running hash state and the applied block count are
// restored and the since-snapshot counter is cleared.
// A decoding failure of either the snapshot or the hash state panics.
func (sm *raftChainStateMachine) loadSnapshot(data []byte) {
	// Snapshots travel through RAFT in JSON format
	var snap chainHashSnapshot
	if err := json.Unmarshal(data, &snap); err != nil {
		panic(fmt.Sprintf("failed to unmarshal snapshot: %s", err))
	}

	sm.logDebug(
		"Applying snapshot (created by %s) taken after %d blocks",
		snap.SourceNode,
		snap.BlocksCount,
	)

	// The hash implementation must be able to import its internal state
	restorer := sm.hash.(encoding.BinaryUnmarshaler)
	if err := restorer.UnmarshalBinary(snap.HashData); err != nil {
		panic(fmt.Sprintf("failed to unmarshal hash data: %s", err))
	}

	// Counters now reflect the snapshot
	sm.blocksApplied = snap.BlocksCount
	sm.blocksAppliedSinceSnapshot = 0

	sm.logDebug("Hash after snapshot with %d blocks: %X ", sm.blocksApplied, sm.hash.Sum(nil))
}

// newRaftChainStateMachine is the factory for raftChainStateMachine: it builds
// a state machine on top of the given server and RAFT node, with an empty
// CRC-32 (IEEE) running hash and a random generator seeded from the server
// name and node ID.
func newRaftChainStateMachine(s *Server, cfg *RaftConfig, n RaftNode) stateMachine {
	// The seed is the sum of all bytes of the server name and the node ID
	var seed int64
	identity := s.Name() + n.ID()
	for i := 0; i < len(identity); i++ {
		seed += int64(identity[i])
	}

	return &raftChainStateMachine{
		s:    s,
		n:    n,
		cfg:  cfg,
		rng:  rand.New(rand.NewSource(seed)),
		hash: crc32.NewIEEE(),
	}
}

0 comments on commit 76267e5

Please sign in to comment.