Skip to content

Commit

Permalink
awslogs: Prevent close from being blocked on log
Browse files Browse the repository at this point in the history
Before this change a call to `Close` could be blocked if the channel
used to buffer logs is full.
When this happens, the container state will end up wedged, causing a
deadlock on anything that needs to lock the container state.

This replaces the channel — whose semantics are difficult to manage
here — with something more suitable for the situation.

Signed-off-by: Brian Goff <cpuguy83@gmail.com>
  • Loading branch information
cpuguy83 committed May 2, 2024
1 parent faf84d7 commit 3821810
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 48 deletions.
46 changes: 24 additions & 22 deletions daemon/logger/awslogs/cloudwatchlogs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"regexp"
"sort"
"strconv"
"sync"
"sync/atomic"
"time"
"unicode/utf8"

Expand Down Expand Up @@ -76,10 +76,11 @@ type logStream struct {
forceFlushInterval time.Duration
multilinePattern *regexp.Regexp
client api
messages chan *logger.Message
lock sync.RWMutex
closed bool
sequenceToken *string

messages *loggerutils.MessageQueue
closed atomic.Bool

sequenceToken *string
}

type logStreamConfig struct {
Expand Down Expand Up @@ -158,7 +159,7 @@ func New(info logger.Info) (logger.Logger, error) {
forceFlushInterval: containerStreamConfig.forceFlushInterval,
multilinePattern: containerStreamConfig.multilinePattern,
client: client,
messages: make(chan *logger.Message, containerStreamConfig.maxBufferedEvents),
messages: loggerutils.NewMessageQueue(containerStreamConfig.maxBufferedEvents),
}

creationDone := make(chan bool)
Expand All @@ -168,12 +169,10 @@ func New(info logger.Info) (logger.Logger, error) {
maxBackoff := 32
for {
// If logger is closed we are done
containerStream.lock.RLock()
if containerStream.closed {
containerStream.lock.RUnlock()
if containerStream.closed.Load() {
break
}
containerStream.lock.RUnlock()

err := containerStream.create()
if err == nil {
break
Expand Down Expand Up @@ -426,25 +425,26 @@ func (l *logStream) BufSize() int {
return maximumBytesPerEvent
}

var errClosed = errors.New("awslogs is closed")

// Log submits messages for logging by an instance of the awslogs logging driver
func (l *logStream) Log(msg *logger.Message) error {
l.lock.RLock()
defer l.lock.RUnlock()
if l.closed {
return errors.New("awslogs is closed")
// No need to check if we are closed here since the queue will be closed
// (i.e. returns false) in this case.
ctx := context.TODO()
if err := l.messages.Enqueue(ctx, msg); err != nil {
if err == loggerutils.ErrQueueClosed {
return errClosed
}
return err
}
l.messages <- msg
return nil
}

// Close closes the instance of the awslogs logging driver
func (l *logStream) Close() error {
l.lock.Lock()
defer l.lock.Unlock()
if !l.closed {
close(l.messages)
}
l.closed = true
l.closed.Store(true)
l.messages.Close()
return nil
}

Expand Down Expand Up @@ -561,6 +561,8 @@ func (l *logStream) collectBatch(created chan bool) {
var eventBuffer []byte
var eventBufferTimestamp int64
batch := newEventBatch()

chLogs := l.messages.Receiver()
for {
select {
case t := <-ticker.C:
Expand All @@ -576,7 +578,7 @@ func (l *logStream) collectBatch(created chan bool) {
}
l.publishBatch(batch)
batch.reset()
case msg, more := <-l.messages:
case msg, more := <-chLogs:
if !more {
// Flush event buffer and release resources
l.processEvent(batch, eventBuffer, eventBufferTimestamp)
Expand Down
55 changes: 29 additions & 26 deletions daemon/logger/awslogs/cloudwatchlogs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,10 @@ func TestCreateAlreadyExists(t *testing.T) {
func TestLogClosed(t *testing.T) {
mockClient := &mockClient{}
stream := &logStream{
client: mockClient,
closed: true,
client: mockClient,
messages: loggerutils.NewMessageQueue(0),
}
stream.Close()
err := stream.Log(&logger.Message{})
assert.Check(t, err != nil)
}
Expand All @@ -370,7 +371,7 @@ func TestLogBlocking(t *testing.T) {
mockClient := &mockClient{}
stream := &logStream{
client: mockClient,
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}

errorCh := make(chan error, 1)
Expand All @@ -387,14 +388,11 @@ func TestLogBlocking(t *testing.T) {
t.Fatal("Expected stream.Log to block: ", err)
default:
}

// assuming it is blocked, we can now try to drain the internal channel and
// unblock it
select {
case <-time.After(10 * time.Millisecond):
// if we're unable to drain the channel within 10ms, something seems broken
t.Fatal("Expected to be able to read from stream.messages but was unable to")
case <-stream.messages:
}
<-stream.messages.Receiver()

select {
case err := <-errorCh:
assert.NilError(t, err)
Expand All @@ -408,7 +406,7 @@ func TestLogBufferEmpty(t *testing.T) {
mockClient := &mockClient{}
stream := &logStream{
client: mockClient,
messages: make(chan *logger.Message, 1),
messages: loggerutils.NewMessageQueue(1),
}
err := stream.Log(&logger.Message{})
assert.NilError(t, err)
Expand Down Expand Up @@ -556,7 +554,7 @@ func TestCollectBatchSimple(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
mockClient.putLogEventsFunc = func(ctx context.Context, input *cloudwatchlogs.PutLogEventsInput, opts ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.PutLogEventsOutput, error) {
Expand All @@ -575,15 +573,20 @@ func TestCollectBatchSimple(t *testing.T) {
close(d)
go stream.collectBatch(d)

stream.Log(&logger.Message{
err := stream.Log(&logger.Message{
Line: []byte(logline),
Timestamp: time.Time{},
})
assert.NilError(t, err)

ticks <- time.Time{}
ticks <- time.Time{}
stream.Close()

for len(calls) != 1 {
time.Sleep(10 * time.Millisecond)
}

assert.Assert(t, len(calls) == 1)
argument := calls[0]
assert.Assert(t, argument != nil)
Expand All @@ -598,7 +601,7 @@ func TestCollectBatchTicker(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -666,7 +669,7 @@ func TestCollectBatchMultilinePattern(t *testing.T) {
logStreamName: streamName,
multilinePattern: multilinePattern,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -732,7 +735,7 @@ func BenchmarkCollectBatch(b *testing.B) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
mockClient.putLogEventsFunc = func(ctx context.Context, input *cloudwatchlogs.PutLogEventsInput, opts ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.PutLogEventsOutput, error) {
return &cloudwatchlogs.PutLogEventsOutput{
Expand Down Expand Up @@ -765,7 +768,7 @@ func BenchmarkCollectBatchMultilinePattern(b *testing.B) {
logStreamName: streamName,
multilinePattern: multilinePattern,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
mockClient.putLogEventsFunc = func(ctx context.Context, input *cloudwatchlogs.PutLogEventsInput, opts ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.PutLogEventsOutput, error) {
return &cloudwatchlogs.PutLogEventsOutput{
Expand Down Expand Up @@ -796,7 +799,7 @@ func TestCollectBatchMultilinePatternMaxEventAge(t *testing.T) {
logStreamName: streamName,
multilinePattern: multilinePattern,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -870,7 +873,7 @@ func TestCollectBatchMultilinePatternNegativeEventAge(t *testing.T) {
logStreamName: streamName,
multilinePattern: multilinePattern,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -927,7 +930,7 @@ func TestCollectBatchMultilinePatternMaxEventSize(t *testing.T) {
logStreamName: streamName,
multilinePattern: multilinePattern,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -987,7 +990,7 @@ func TestCollectBatchClose(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -1089,7 +1092,7 @@ func TestCollectBatchLineSplit(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -1137,7 +1140,7 @@ func TestCollectBatchLineSplitWithBinary(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -1185,7 +1188,7 @@ func TestCollectBatchMaxEvents(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -1239,7 +1242,7 @@ func TestCollectBatchMaxTotalBytes(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -1320,7 +1323,7 @@ func TestCollectBatchMaxTotalBytesWithBinary(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -1394,7 +1397,7 @@ func TestCollectBatchWithDuplicateTimestamps(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down

0 comments on commit 3821810

Please sign in to comment.