Skip to content

Commit

Permalink
Merge pull request #3 from krumIO/feat/multi-region
Browse files Browse the repository at this point in the history
feat: multi region replica strike implementation
  • Loading branch information
grudra7714 committed Oct 26, 2023
2 parents 9496f0e + 914cab4 commit 71a3f4d
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 35 deletions.
2 changes: 1 addition & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ var (
"CCC-Taxonomy": {
Strikes.SQLFeatures,
Strikes.AutomatedBackups,
Strikes.MultiRegion,
// Strikes.VerticalScaling,
// Strikes.Replication,
// Strikes.MultiRegion,
// Strikes.BackupRecovery,
// Strikes.Encryption,
// Strikes.RBAC,
Expand Down
1 change: 1 addition & 0 deletions example-config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ raids:
config:
instance_identifier: unique-id-name
database: test
primary_region: us-east-1
host: localhost
password: password
port: 3306
Expand Down
39 changes: 7 additions & 32 deletions strikes/AutomatedBackups.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (a *Strikes) AutomatedBackups() (strikeName string, result raidengine.Strik
Movements: make(map[string]raidengine.MovementResult),
}

// Movement
// Get Configuration
cfg, err := getAWSConfig()
if err != nil {
result.Message = err.Error()
Expand All @@ -37,10 +37,10 @@ func (a *Strikes) AutomatedBackups() (strikeName string, result raidengine.Strik
return
}

autmatedBackupsMovement := checkRDSAutomatedBackupMovement(cfg)
result.Movements["CheckForDBInstanceAutomatedBackups"] = autmatedBackupsMovement
if !autmatedBackupsMovement.Passed {
result.Message = autmatedBackupsMovement.Message
automatedBackupsMovement := checkRDSAutomatedBackupMovement(cfg)
result.Movements["CheckForDBInstanceAutomatedBackups"] = automatedBackupsMovement
if !automatedBackupsMovement.Passed {
result.Message = automatedBackupsMovement.Message
return
}

Expand All @@ -49,31 +49,6 @@ func (a *Strikes) AutomatedBackups() (strikeName string, result raidengine.Strik
return
}

func checkRDSInstanceMovement(cfg aws.Config) (result raidengine.MovementResult) {
// check if the instance is available
result = raidengine.MovementResult{
Description: "Check if the instance is available/exists",
Function: utils.CallerPath(0),
}

rdsClient := rds.NewFromConfig(cfg)
identifier, _ := getDBInstanceIdentifier()

input := &rds.DescribeDBInstancesInput{
DBInstanceIdentifier: aws.String(identifier),
}

instances, err := rdsClient.DescribeDBInstances(context.TODO(), input)
if err != nil {
// Handle error
result.Message = err.Error()
result.Passed = false
return
}
result.Passed = len(instances.DBInstances) > 0
return
}

func checkRDSAutomatedBackupMovement(cfg aws.Config) (result raidengine.MovementResult) {

result = raidengine.MovementResult{
Expand All @@ -82,10 +57,10 @@ func checkRDSAutomatedBackupMovement(cfg aws.Config) (result raidengine.Movement
}

rdsClient := rds.NewFromConfig(cfg)
identifier, _ := getDBInstanceIdentifier()
instanceIdentifier, _ := getHostDBInstanceIdentifier()

input := &rds.DescribeDBInstanceAutomatedBackupsInput{
DBInstanceIdentifier: aws.String(identifier),
DBInstanceIdentifier: aws.String(instanceIdentifier),
}

backups, err := rdsClient.DescribeDBInstanceAutomatedBackups(context.TODO(), input)
Expand Down
98 changes: 98 additions & 0 deletions strikes/MultiRegion.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package strikes

import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/privateerproj/privateer-sdk/raidengine"
"github.com/privateerproj/privateer-sdk/utils"
)

func (a *Strikes) MultiRegion() (strikeName string, result raidengine.StrikeResult) {
strikeName = "MultiRegion"
result = raidengine.StrikeResult{
Passed: false,
Description: "Check if AWS RDS instance has multi region. This strike only checks for a read replica in a seperate region",
DocsURL: "https://www.github.com/krumIO/raid-rds",
ControlID: "CCC-Taxonomy-1",
Movements: make(map[string]raidengine.MovementResult),
}

// Get Configuration
cfg, err := getAWSConfig()
if err != nil {
result.Message = err.Error()
return
}

rdsInstanceMovement := checkRDSInstanceMovement(cfg)
result.Movements["CheckForDBInstance"] = rdsInstanceMovement
if !rdsInstanceMovement.Passed {
result.Message = rdsInstanceMovement.Message
return
}

multiRegionMovement := checkRDSMultiRegionMovement(cfg)
result.Movements["CheckForMultiRegionDBInstances"] = multiRegionMovement
if !multiRegionMovement.Passed {
result.Message = multiRegionMovement.Message
return
}

result.Passed = true
result.Message = "Completed Successfully"

return
}

func checkRDSMultiRegionMovement(cfg aws.Config) (result raidengine.MovementResult) {

result = raidengine.MovementResult{
Description: "Check if the instance has multi region enabled",
Function: utils.CallerPath(0),
}
instanceIdentifier, _ := getHostDBInstanceIdentifier()

instance, _ := getRDSInstanceFromIdentifier(cfg, instanceIdentifier)

// get read replicas from the instance
readReplicas := instance.DBInstances[0].ReadReplicaDBInstanceIdentifiers

if len(readReplicas) == 0 {
result.Passed = false
result.Message = "Multi Region instances not found"
return
}

hostRDSRegion, _ := getHostRDSRegion()

// loop over the read replicas and check if they are in a different region
for _, replica := range readReplicas {
// we are getting the instance identifier the read replicas
// get instance from the replica identifier
replicaInstance, err := getRDSInstanceFromIdentifier(cfg, replica)

if err != nil {
result.Passed = false
result.Message = err.Error()
return
}

if len(replicaInstance.DBInstances) == 0 {
result.Passed = false
result.Message = "Cannot access the replica instance " + replica
return
}

// check if replica region matches the host region
az := *replicaInstance.DBInstances[0].AvailabilityZone
// db instance doesnt contain the region so we need to remove the last character from the az
if az[:len(az)-1] == hostRDSRegion {
result.Passed = false
result.Message = "Multi Region instances not found"
return
}
}

result.Passed = true
return

}
32 changes: 32 additions & 0 deletions strikes/MultiRegion_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package strikes

import (
"encoding/json"
"fmt"
"testing"

"github.com/spf13/viper"
)

func TestMultiRegion(t *testing.T) {
viper.AddConfigPath("../")
viper.SetConfigName("config")
viper.SetConfigType("yaml")
err := viper.ReadInConfig()

if err != nil {
fmt.Println("Config file not found...")
return
}

strikes := Strikes{}
strikeName, result := strikes.MultiRegion()

fmt.Println(strikeName)
b, err := json.MarshalIndent(result, "", " ")
if err != nil {
fmt.Println(err)
}
fmt.Print(string(b))
fmt.Println()
}
44 changes: 42 additions & 2 deletions strikes/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/rds"
hclog "github.com/hashicorp/go-hclog"
"github.com/privateerproj/privateer-sdk/raidengine"
"github.com/privateerproj/privateer-sdk/utils"
Expand All @@ -32,17 +33,25 @@ func getDBConfig() (string, error) {
return "", errors.New("database url must be set in the config file")
}

func getDBInstanceIdentifier() (string, error) {
func getHostDBInstanceIdentifier() (string, error) {
if viper.IsSet("raids.RDS.aws.config.instance_identifier") {
return viper.GetString("raids.RDS.aws.config.instance_identifier"), nil
}
return "", errors.New("database instance identifier must be set in the config file")
}

func getHostRDSRegion() (string, error) {
if viper.IsSet("raids.RDS.aws.config.primary_region") {
return viper.GetString("raids.RDS.aws.config.primary_region"), nil
}
return "", errors.New("database instance identifier must be set in the config file")
}

func getAWSConfig() (cfg aws.Config, err error) {
if viper.IsSet("raids.RDS.aws.creds") &&
viper.IsSet("raids.RDS.aws.creds.aws_access_key") &&
viper.IsSet("raids.RDS.aws.creds.aws_secret_key") {
viper.IsSet("raids.RDS.aws.creds.aws_secret_key") &&
viper.IsSet("raids.RDS.aws.creds.aws_region") {

access_key := viper.GetString("raids.RDS.aws.creds.aws_access_key")
secret_key := viper.GetString("raids.RDS.aws.creds.aws_secret_key")
Expand All @@ -68,3 +77,34 @@ func connectToDb() (result raidengine.MovementResult) {
result.Passed = true
return
}

func checkRDSInstanceMovement(cfg aws.Config) (result raidengine.MovementResult) {
// check if the instance is available
result = raidengine.MovementResult{
Description: "Check if the instance is available/exists",
Function: utils.CallerPath(0),
}

instanceIdentifier, _ := getHostDBInstanceIdentifier()

instance, err := getRDSInstanceFromIdentifier(cfg, instanceIdentifier)
if err != nil {
// Handle error
result.Message = err.Error()
result.Passed = false
return
}
result.Passed = len(instance.DBInstances) > 0
return
}

func getRDSInstanceFromIdentifier(cfg aws.Config, identifier string) (instance *rds.DescribeDBInstancesOutput, err error) {
rdsClient := rds.NewFromConfig(cfg)

input := &rds.DescribeDBInstancesInput{
DBInstanceIdentifier: aws.String(identifier),
}

instance, err = rdsClient.DescribeDBInstances(context.TODO(), input)
return
}

0 comments on commit 71a3f4d

Please sign in to comment.