Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more RAFT tests #4626

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion server/jetstream_helpers_test.go
Expand Up @@ -24,6 +24,7 @@ import (
"net"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -785,8 +786,9 @@ func createJetStreamClusterEx(t testing.TB, tmpl, cName, snPre string, numServer
}

for cp := portStart; cp < portStart+numServers; cp++ {
storeDir := t.TempDir()
sn := fmt.Sprintf("%sS-%d", snPre, cp-portStart+1)
storeDir := filepath.Join(t.TempDir(), sn)
fmt.Printf("Store dir: %s => %s\n", sn, storeDir)
conf := fmt.Sprintf(tmpl, sn, storeDir, cName, cp, routeConfig)
if modify != nil {
conf = modify(sn, cName, storeDir, conf)
Expand Down
5 changes: 5 additions & 0 deletions server/raft.go
Expand Up @@ -1123,6 +1123,7 @@ func (n *raft) setupLastSnapshot() {
sfile := filepath.Join(snapDir, sf.Name())
var term, index uint64
term, index, err := termAndIndexFromSnapFile(sf.Name())
fmt.Printf("Found snapshot [term: %d, index: %d] (%s)\n", term, index, sf.Name())
if err == nil {
if term > lterm {
lterm, lindex = term, index
Expand Down Expand Up @@ -1179,11 +1180,14 @@ func (n *raft) setupLastSnapshot() {
if _, err := n.wal.Compact(snap.lastIndex + 1); err != nil {
n.setWriteErrLocked(err)
}
fmt.Printf("Loaded last snapshot successfully\n")
}

// loadLastSnapshot will load and return our last snapshot.
// Lock should be held.
func (n *raft) loadLastSnapshot() (*snapshot, error) {
fmt.Printf("Loading last snapshot: %s\n", n.snapfile)

if n.snapfile == _EMPTY_ {
return nil, errNoSnapAvailable
}
Expand Down Expand Up @@ -1231,6 +1235,7 @@ func (n *raft) loadLastSnapshot() (*snapshot, error) {
return nil, errSnapshotCorrupt
}

fmt.Printf("Loaded last snapshot: %s\n", n.snapfile)
return snap, nil
}

Expand Down
319 changes: 317 additions & 2 deletions server/raft_helpers_test.go
Expand Up @@ -17,9 +17,14 @@
package server

import (
"encoding"
"encoding/binary"
"encoding/json"
"fmt"
"hash"
"hash/crc32"
"math/rand"
"path/filepath"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -108,12 +113,15 @@ func (c *cluster) createRaftGroupWithPeers(name string, servers []*Server, smf s
}

for _, s := range servers {
baseDir := filepath.Join(c.t.TempDir(), fmt.Sprintf("%s-%s", name, s.Name()))
logDir := filepath.Join(baseDir, "log")
storeDir := filepath.Join(baseDir, "store")
fs, err := newFileStore(
FileStoreConfig{StoreDir: c.t.TempDir(), BlockSize: defaultMediumBlockSize, AsyncFlush: false, SyncInterval: 5 * time.Minute},
FileStoreConfig{StoreDir: logDir, BlockSize: defaultMediumBlockSize, AsyncFlush: false, SyncInterval: 5 * time.Minute},
StreamConfig{Name: name, Storage: FileStorage},
)
require_NoError(c.t, err)
cfg := &RaftConfig{Name: name, Store: c.t.TempDir(), Log: fs}
cfg := &RaftConfig{Name: name, Store: storeDir, Log: fs}
s.bootstrapRaftNode(cfg, peers, true)
n, err := s.startRaftNode(globalAccountName, cfg, pprofLabels{})
require_NoError(c.t, err)
Expand Down Expand Up @@ -274,3 +282,310 @@ func (rg smGroup) waitOnTotal(t *testing.T, expected int64) {
func newStateAdder(s *Server, cfg *RaftConfig, n RaftNode) stateMachine {
return &stateAdder{s: s, n: n, cfg: cfg}
}

var RaftChainOptions = struct {
verbose bool
maxBlockSize int
}{
false,
25,
}

// Simple implementation of a replicated state machine on top of RAFT.
// Hash each value delivered on top of existing hash
// All replicas should go through the same sequence of block hashes
type raftChainStateMachine struct {
sync.Mutex
s *Server
n RaftNode
cfg *RaftConfig
leader bool
proposalSequence uint64
rng *rand.Rand
hash hash.Hash
blocksApplied uint64
blocksAppliedSinceSnapshot uint64
stopped bool
ready bool
safeSnapshots bool
}

// Block is just a random array of bytes, but contains a little extra metadata to track its source
type ChainBlock struct {
Proposer string
ProposerSequence uint64
Data []byte
}

func (sm *raftChainStateMachine) logDebug(format string, args ...any) {
if RaftChainOptions.verbose {
fmt.Printf("["+sm.s.Name()+" ("+sm.n.ID()+")] "+format+"\n", args...)
}
}

func (sm *raftChainStateMachine) server() *Server {
sm.Lock()
defer sm.Unlock()
return sm.s
}

func (sm *raftChainStateMachine) node() RaftNode {
sm.Lock()
defer sm.Unlock()
return sm.n
}

func (sm *raftChainStateMachine) propose(data []byte) {
sm.Lock()
defer sm.Unlock()
if !sm.ready {
sm.logDebug("Refusing to propose during recovery")
}
err := sm.n.ForwardProposal(data)
if err != nil {
sm.logDebug("block proposal error: %s", err)
}
}

func (sm *raftChainStateMachine) applyEntry(ce *CommittedEntry) {
sm.Lock()
defer sm.Unlock()
if ce == nil {
// A nil CE is a signal the previous recovery backlog is over
sm.logDebug("Recovery complete")
sm.ready = true
return
}
sm.logDebug("Apply entries #%d (%d entries)", ce.Index, len(ce.Entries))
for _, entry := range ce.Entries {
if entry.Type == EntryNormal {
sm.applyBlock(entry.Data)
} else if entry.Type == EntrySnapshot {
sm.loadSnapshot(entry.Data)
} else {
panic(fmt.Sprintf("[%s] unknown entry type: %s", sm.s.Name(), entry.Type))
}
}
sm.n.Applied(ce.Index)
}

func (sm *raftChainStateMachine) leaderChange(isLeader bool) {
if sm.leader && !isLeader {
sm.logDebug("Leader change: no longer leader")
} else if sm.leader && isLeader {
sm.logDebug("Elected leader while already leader")
} else if !sm.leader && isLeader {
sm.logDebug("Leader change: i am leader")
} else {
sm.logDebug("Leader change")
}
sm.leader = isLeader
if isLeader != sm.node().Leader() {
sm.logDebug("⚠️ Leader state out of sync with underlying node")
}
}

func (sm *raftChainStateMachine) stop() {
sm.Lock()
defer sm.Unlock()
sm.n.Stop()

// Clear state, on restart it will be recovered from snapshot or peers
sm.stopped = true
sm.blocksApplied = 0
sm.hash.Reset()
sm.leader = false
sm.logDebug("Stopped")
}

func (sm *raftChainStateMachine) restart() {
sm.Lock()
defer sm.Unlock()

sm.logDebug("Restarting")

sm.stopped = false
sm.ready = false
if sm.n.State() != Closed {
return
}

// The filestore is stopped as well, so need to extract the parts to recreate it.
rn := sm.n.(*raft)
fs := rn.wal.(*fileStore)

var err error
sm.cfg.Log, err = newFileStore(fs.fcfg, fs.cfg.StreamConfig)
if err != nil {
panic(err)
}
sm.n, err = sm.s.startRaftNode(globalAccountName, sm.cfg, pprofLabels{})
if err != nil {
panic(err)
}
// Finally restart the driver.
go smLoop(sm)
}

func (sm *raftChainStateMachine) proposeBlock() {
// Track how many blocks this replica proposed
sm.proposalSequence += 1
// Create a block
block := ChainBlock{
Proposer: sm.s.Name(),
ProposerSequence: sm.proposalSequence,
Data: make([]byte, sm.rng.Intn(20)+1),
}
// Data is random bytes
sm.rng.Read(block.Data)
// Serialize as JSON
blockData, err := json.Marshal(block)
if err != nil {
panic(fmt.Sprintf("serialization error: %s", err))
}
sm.logDebug(
"Proposing block <%s, %d, [%dB]>",
block.Proposer,
block.ProposerSequence,
len(block.Data),
)

// Propose (may silently fail if this replica is not leader, or other reasons)
sm.propose(blockData)
}

func (sm *raftChainStateMachine) applyBlock(data []byte) {
// Deserialize block received in JSON format
var block ChainBlock
err := json.Unmarshal(data, &block)
if err != nil {
panic(fmt.Sprintf("deserialization error: %s", err))
}
sm.logDebug("Applying block <%s, %d>", block.Proposer, block.ProposerSequence)

// Hash the data on top of the existing running hash
n, err := sm.hash.Write(block.Data)
if n != len(block.Data) {
panic(fmt.Sprintf("unexpected checksum written %d data block size: %d", n, len(block.Data)))
} else if err != nil {
panic(fmt.Sprintf("checksum error: %s", err))
}

// Track block number
sm.blocksApplied += 1
sm.blocksAppliedSinceSnapshot += 1

sm.logDebug("Hash after %d blocks: %X ", sm.blocksApplied, sm.hash.Sum(nil))
}

func (sm *raftChainStateMachine) getCurrentHash() (bool, uint64, string) {
sm.Lock()
defer sm.Unlock()

// Return running, the number of blocks applied and the current running hash
return !sm.stopped, sm.blocksApplied, fmt.Sprintf("%X", sm.hash.Sum(nil))
}

type chainHashSnapshot struct {
SourceNode string
HashData []byte
BlocksCount uint64
}

func (sm *raftChainStateMachine) snapshot() {
sm.Lock()
defer sm.Unlock()

if sm.blocksAppliedSinceSnapshot == 0 {
sm.logDebug("Skip snapshot, no new entries")
return
}

if sm.safeSnapshots && !sm.ready {
sm.logDebug("Skip snapshot, still recovering")
return
}

sm.logDebug(
"Snapshot (with %d blocks applied, %d since last snapshot)",
sm.blocksApplied,
sm.blocksAppliedSinceSnapshot,
)

// Serialize the internal state of the hash block
serializedHash, err := sm.hash.(encoding.BinaryMarshaler).MarshalBinary()
if err != nil {
panic(fmt.Sprintf("failed to marshal hash: %s", err))
}

// Create snapshot
snapshot := chainHashSnapshot{
SourceNode: fmt.Sprintf("%s (%s)", sm.s.Name(), sm.n.ID()),
HashData: serializedHash,
BlocksCount: sm.blocksApplied,
}

// Serialize snapshot as JSON
snapshotData, err := json.Marshal(snapshot)
if err != nil {
panic(fmt.Sprintf("failed to marshal snapshot: %s", err))
}

// Install it as byte array
err = sm.n.InstallSnapshot(snapshotData)
if err != nil {
panic(fmt.Sprintf("failed to snapshot: %s", err))
}

// Reset counter since last snapshot
sm.blocksAppliedSinceSnapshot = 0
}

func (sm *raftChainStateMachine) loadSnapshot(data []byte) {
// Deserialize snapshot from JSON
var snapshot chainHashSnapshot
err := json.Unmarshal(data, &snapshot)
if err != nil {
panic(fmt.Sprintf("failed to unmarshal snapshot: %s", err))
}

sm.logDebug(
"Applying snapshot (created by %s) taken after %d blocks",
snapshot.SourceNode,
snapshot.BlocksCount,
)

// Load internal hash block state
err = sm.hash.(encoding.BinaryUnmarshaler).UnmarshalBinary(snapshot.HashData)
if err != nil {
panic(fmt.Sprintf("failed to unmarshal hash data: %s", err))
}

// Load block counter
sm.blocksApplied = snapshot.BlocksCount
sm.blocksAppliedSinceSnapshot = 0

sm.logDebug("Hash after snapshot with %d blocks: %X ", sm.blocksApplied, sm.hash.Sum(nil))
}

// Factory function to create RaftChainStateMachine on top of the given server/node
func newRaftChainStateMachine(s *Server, cfg *RaftConfig, n RaftNode) stateMachine {
// Create RNG seed based on server name and node id
var seed int64
for _, c := range []byte(s.Name()) {
seed += int64(c)
}
for _, c := range []byte(n.ID()) {
seed += int64(c)
}
rng := rand.New(rand.NewSource(seed))

// Initialize empty hash block
hashBlock := crc32.NewIEEE()

// Set to true to make RCSM ignore snapshot requests during 'recovery'.
// i.e. after a restart and before a nil commit entry is consumed from the transitions queue.
var safeSnapshots bool

return &raftChainStateMachine{s: s, n: n, cfg: cfg, rng: rng, hash: hashBlock, safeSnapshots: safeSnapshots}
}