Skip to content

Commit

Permalink
[Bug] Resolve some issues with RuleFeeds and Rule Downloads not worki…
Browse files Browse the repository at this point in the history
…ng with latest `TeamIDs` and `SigningIDs` (#49)

* lint

* updated return types for global rules and feedrules queries

* add testing coverage for these types

* refactor other references/outputs
  • Loading branch information
radsec committed Apr 5, 2024
1 parent 8ecc610 commit f5ec879
Show file tree
Hide file tree
Showing 10 changed files with 258 additions and 48 deletions.
6 changes: 3 additions & 3 deletions internal/cli/rules/rules-export.go
Expand Up @@ -158,13 +158,13 @@ func getRules(client dynamodb.QueryAPI, callback func(globalrules.GlobalRuleRow)
err = fmt.Errorf("something went wrong querying global rules: %w", inerr)
return
}
if len(*rules) == 0 {
if len(rules) == 0 {
break
}

for _, rule := range *rules {
for _, rule := range rules {
total += 1
err = callback(rule)
err = callback(*rule)
if err != nil {
return
}
Expand Down
4 changes: 2 additions & 2 deletions internal/cli/rules/rules.go
Expand Up @@ -57,13 +57,13 @@ func rules(client dynamodb.QueryAPI, tf flags.TargetFlags, limit int) error {
return fmt.Errorf("failed to get rules: %w", err)
}

ruleCount := len(*rules)
ruleCount := len(rules)
if lastEvaluatedKey != nil {
fmt.Printf("Retrieved more than %d rules:\n", ruleCount)
} else {
fmt.Printf("Retrieved %d rules:\n", ruleCount)
}
for i, rule := range *rules {
for i, rule := range rules {
fmt.Println("----- [", i, "] (", rule.SortKey, ")")
fmt.Printf("%s: %s\n", renderRule(rule.SantaRule), rule.Description)
fmt.Println("")
Expand Down
5 changes: 1 addition & 4 deletions internal/handlers/preflight/clean_sync.go
Expand Up @@ -47,10 +47,7 @@ func (c concreteCleanSyncService) determineCleanSync(machineID string, preflight

func determineCleanSyncByRuleCount(preflightRequest *PreflightRequest) bool {
ruleCount := preflightRequest.BinaryRuleCount + preflightRequest.CertificateRuleCount + preflightRequest.CompilerRuleCount + preflightRequest.TransitiveRuleCount
if ruleCount == 0 {
return true
}
return false
return ruleCount == 0
}

func determineCleanSyncRefresh(timeProvider clock.TimeProvider, machineID string, syncState *syncstate.SyncStateRow) (performCleanSync bool, err error) {
Expand Down
9 changes: 6 additions & 3 deletions internal/handlers/ruledownload/clean_sync.go
Expand Up @@ -41,13 +41,16 @@ func (d concreteGlobalRuleDownloader) handle(machineID string, cursor ruledownlo
nextCursor.SetDynamodbLastEvaluatedKey(lastEvaluatedKey)
}

rules := make([]rules.SantaRule, len(*globalRules))
for i, rule := range *globalRules {
rules := make([]rules.SantaRule, len(globalRules))
for i, rule := range globalRules {
rules[i] = rule.SantaRule
}

return response.APIResponse(
http.StatusOK,
RuledownloadResponse{Rules: DDBRulesToResponseRules(rules), Cursor: &nextCursor},
RuledownloadResponse{
Rules: DDBRulesToResponseRules(rules),
Cursor: &nextCursor,
},
)
}
9 changes: 6 additions & 3 deletions internal/handlers/ruledownload/feed_sync.go
Expand Up @@ -65,13 +65,16 @@ func (d concreteFeedRuleDownloader) handle(machineID string, cursor ruledownload
nextCursor.SetDynamodbLastEvaluatedKey(lastEvaluatedKey)
}

rules := make([]rules.SantaRule, len(*feedRules))
for i, rule := range *feedRules {
rules := make([]rules.SantaRule, len(feedRules))
for i, rule := range feedRules {
rules[i] = rule.SantaRule
}

return response.APIResponse(
http.StatusOK,
RuledownloadResponse{Rules: DDBRulesToResponseRules(rules), Cursor: &nextCursor},
RuledownloadResponse{
Rules: DDBRulesToResponseRules(rules),
Cursor: &nextCursor,
},
)
}
21 changes: 0 additions & 21 deletions pkg/dynamodb/query.go
Expand Up @@ -27,27 +27,6 @@ func query(tableName string, api dynamodbQueryAPI, input *dynamodb.QueryInput) (
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()

// expressionAttributeValues := map[string]types.AttributeValue{
// ":pk": &types.AttributeValueMemberS{
// Value: partitionKey,
// },
// }
// keyConditionExpression := aws.String("PK = :pk")

// var exclusiveStartKey map[string]types.AttributeValue
// input := &dynamodb.QueryInput{
// TableName: aws.String(tableName),
// ConsistentRead: aws.Bool(consistentRead),
// ExpressionAttributeValues: expressionAttributeValues,
// KeyConditionExpression: keyConditionExpression,
// ExclusiveStartKey: exclusiveStartKey,
// FilterExpression: ,
// Limit: limit,
// }

// // var exclusiveStartKey map[string]types.AttributeValue{}
// exclusiveStartKey, _ = attributevalue.MarshalMap(cursor)

input.TableName = aws.String(tableName)

return api.Query(ctx, input)
Expand Down
102 changes: 102 additions & 0 deletions pkg/dynamodb/query_test.go
@@ -0,0 +1,102 @@
package dynamodb

import (
"context"
"fmt"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/stretchr/testify/assert"
)

type mockQueryAPI func(ctx context.Context, in *dynamodb.QueryInput, optFns ...func(*dynamodb.Options)) (*dynamodb.QueryOutput, error)

func (m mockQueryAPI) Query(ctx context.Context, in *dynamodb.QueryInput, optFns ...func(*dynamodb.Options)) (*dynamodb.QueryOutput, error) {
return m(ctx, in, optFns...)
}

// Per recommendation from AWS docs: https://aws.github.io/aws-sdk-go-v2/docs/unit-testing/
func Test_query(t *testing.T) {
output, err := query(
"test_table",
mockQueryAPI(func(ctx context.Context, in *dynamodb.QueryInput, optFns ...func(*dynamodb.Options)) (*dynamodb.QueryOutput, error) {
assert.Equal(t, "test_table", *in.TableName)
assert.Equal(t, "PK = :PK AND SK = :SK", aws.ToString(in.KeyConditionExpression))
assert.Equal(t, aws.Int32(5), in.Limit)
assert.Equal(
t,
map[string]string{
"#PK": "PK",
"#SK": "SK",
"#Number": "Number",
},
in.ExpressionAttributeNames,
)
assert.Equal(
t,
map[string]types.AttributeValue{},
in.ExpressionAttributeValues,
)

countItems := 3
items := make([]map[string]types.AttributeValue, countItems)

for i := range items {
items[i] = map[string]types.AttributeValue{
"PK": &types.AttributeValueMemberS{
Value: fmt.Sprintf("PK#AA#%d", i),
},
"SK": &types.AttributeValueMemberS{
Value: fmt.Sprintf("SK#BB#%d", i),
},
"Number": &types.AttributeValueMemberN{
Value: fmt.Sprintf("%d", i),
},
}
}

return &dynamodb.QueryOutput{
Items: items,
Count: int32(countItems),
LastEvaluatedKey: map[string]types.AttributeValue{},
}, nil
}),
&dynamodb.QueryInput{
TableName: aws.String("test_table"),
KeyConditionExpression: aws.String("PK = :PK AND SK = :SK"),
ExpressionAttributeNames: map[string]string{
"#PK": "PK",
"#SK": "SK",
"#Number": "Number",
},
ExpressionAttributeValues: map[string]types.AttributeValue{},
Limit: aws.Int32(5),
},
)

if err != nil {
t.Errorf("Error was not expected: %v", err)
}

if output == nil {
t.Errorf("Output was not expected: %v", output)
}

var items []map[string]interface{}

err = attributevalue.UnmarshalListOfMaps(output.Items, &items)
if err != nil {
t.Errorf("failed to unmarshal result from DynamoDB: %s", err.Error())
}

assert.NotEmpty(t, items)
assert.Equal(t, 3, len(items))
for i, item := range items {
assert.Equal(t, fmt.Sprintf("PK#AA#%d", i), item["PK"])
assert.Equal(t, fmt.Sprintf("SK#BB#%d", i), item["SK"])
assert.Equal(t, float64(i), item["Number"])
}
}
4 changes: 2 additions & 2 deletions pkg/model/feedrules/query.go
Expand Up @@ -20,7 +20,7 @@ func GetPaginatedFeedRules(
limit int,
exclusiveStartKey *dynamodb.PrimaryKey,
) (
items *[]FeedRuleRow,
items []*FeedRuleRow,
lastEvaluatedKey *dynamodb.PrimaryKey,
err error,
) {
Expand Down Expand Up @@ -77,7 +77,7 @@ func GetPaginatedFeedRules(
// log.Printf(" got %d items from query.", len(*items))

// To support legacy SHA256 types, we must transform the datasets before returning
for _, item := range *items {
for _, item := range items {
if item.SHA256 != "" && item.Identifier == "" {
item.Identifier = item.SHA256
}
Expand Down
27 changes: 17 additions & 10 deletions pkg/model/globalrules/query.go
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/airbnb/rudolph/pkg/dynamodb"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression"
awsdynamodb "github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)
Expand All @@ -21,7 +22,7 @@ func GetPaginatedGlobalRules(
limit int,
exclusiveStartKey *dynamodb.PrimaryKey,
) (
items *[]GlobalRuleRow,
items []*GlobalRuleRow,
lastEvaluatedKey *dynamodb.PrimaryKey,
err error,
) {
Expand All @@ -32,10 +33,10 @@ func GetPaginatedGlobalRules(
return
}

keyConditionExpression := aws.String("PK = :pk")
expressionAttributeValues := map[string]types.AttributeValue{
":pk": &types.AttributeValueMemberS{Value: partitionKey},
}
keyCond := expression.KeyEqual(
expression.Key("PK"), expression.Value(partitionKey),
)

var exclusiveStartKeyInput map[string]types.AttributeValue
if exclusiveStartKey != nil {
exclusiveStartKeyInput, err = attributevalue.MarshalMap(exclusiveStartKey)
Expand All @@ -45,10 +46,16 @@ func GetPaginatedGlobalRules(
}
}

expr, err := expression.NewBuilder().WithKeyCondition(keyCond).Build()
if err != nil {
return
}

input := &awsdynamodb.QueryInput{
ConsistentRead: aws.Bool(false),
ExpressionAttributeValues: expressionAttributeValues,
KeyConditionExpression: keyConditionExpression,
ExpressionAttributeNames: expr.Names(),
ExpressionAttributeValues: expr.Values(),
KeyConditionExpression: expr.KeyCondition(),
ExclusiveStartKey: exclusiveStartKeyInput,
Limit: aws.Int32(int32(limit)),
}
Expand All @@ -64,19 +71,19 @@ func GetPaginatedGlobalRules(
if result.LastEvaluatedKey != nil {
err = attributevalue.UnmarshalMap(result.LastEvaluatedKey, &lastEvaluatedKey)
if err != nil {
err = fmt.Errorf("failed to unmarshall LastEvaluatedKey: %w", err)
err = fmt.Errorf("failed to UnmarshalMap LastEvaluatedKey: %w", err)
return
}
}

err = attributevalue.UnmarshalListOfMaps(result.Items, &items)
if err != nil {
err = fmt.Errorf("failed to unmarshal result from DynamoDB: %w", err)
err = fmt.Errorf("failed to UnmarshalListOfMaps result from DynamoDB: %w", err)
return
}

// To support legacy SHA256 types, we must transform the datasets before returning
for _, item := range *items {
for _, item := range items {
if item.SHA256 != "" && item.Identifier == "" {
item.Identifier = item.SHA256
}
Expand Down

0 comments on commit f5ec879

Please sign in to comment.