Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

awslogs: Prevent close from being blocked on log #47748

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 24 additions & 22 deletions daemon/logger/awslogs/cloudwatchlogs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"regexp"
"sort"
"strconv"
"sync"
"sync/atomic"
"time"
"unicode/utf8"

Expand Down Expand Up @@ -76,10 +76,11 @@ type logStream struct {
forceFlushInterval time.Duration
multilinePattern *regexp.Regexp
client api
messages chan *logger.Message
lock sync.RWMutex
closed bool
sequenceToken *string

messages *loggerutils.MessageQueue
closed atomic.Bool

sequenceToken *string
}

type logStreamConfig struct {
Expand Down Expand Up @@ -158,7 +159,7 @@ func New(info logger.Info) (logger.Logger, error) {
forceFlushInterval: containerStreamConfig.forceFlushInterval,
multilinePattern: containerStreamConfig.multilinePattern,
client: client,
messages: make(chan *logger.Message, containerStreamConfig.maxBufferedEvents),
messages: loggerutils.NewMessageQueue(containerStreamConfig.maxBufferedEvents),
}

creationDone := make(chan bool)
Expand All @@ -168,12 +169,10 @@ func New(info logger.Info) (logger.Logger, error) {
maxBackoff := 32
for {
// If logger is closed we are done
containerStream.lock.RLock()
if containerStream.closed {
containerStream.lock.RUnlock()
if containerStream.closed.Load() {
break
}
containerStream.lock.RUnlock()

err := containerStream.create()
if err == nil {
break
Expand Down Expand Up @@ -426,25 +425,26 @@ func (l *logStream) BufSize() int {
return maximumBytesPerEvent
}

var errClosed = errors.New("awslogs is closed")

// Log submits messages for logging by an instance of the awslogs logging driver
func (l *logStream) Log(msg *logger.Message) error {
l.lock.RLock()
defer l.lock.RUnlock()
if l.closed {
return errors.New("awslogs is closed")
// No need to check if we are closed here since the queue will be closed
// (i.e. returns false) in this case.
ctx := context.TODO()
if err := l.messages.Enqueue(ctx, msg); err != nil {
if err == loggerutils.ErrQueueClosed {
return errClosed
}
return err
}
l.messages <- msg
return nil
}

// Close closes the instance of the awslogs logging driver
func (l *logStream) Close() error {
l.lock.Lock()
defer l.lock.Unlock()
if !l.closed {
close(l.messages)
}
l.closed = true
l.closed.Store(true)
l.messages.Close()
return nil
}

Expand Down Expand Up @@ -561,6 +561,8 @@ func (l *logStream) collectBatch(created chan bool) {
var eventBuffer []byte
var eventBufferTimestamp int64
batch := newEventBatch()

chLogs := l.messages.Receiver()
for {
select {
case t := <-ticker.C:
Expand All @@ -576,7 +578,7 @@ func (l *logStream) collectBatch(created chan bool) {
}
l.publishBatch(batch)
batch.reset()
case msg, more := <-l.messages:
case msg, more := <-chLogs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The channel will be closed once the message queue is closed, so any buffered messages will not be handled by the current read implementation. This behavior existed before though, so perhaps it should be a separate issue.

Admittedly I also cannot think of a clean solution to abstract the complexity of reading from the underlying channel of the message queue structure here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should only return after the buffer is emptied (more is false only after the channel is closed and the last buffered message has been drained).
https://go.dev/play/p/NR4WOn-XUCs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice! TIL, the wording in https://go.dev/tour/concurrency/4 threw me off, but your example proves the correct behavior.

if !more {
// Flush event buffer and release resources
l.processEvent(batch, eventBuffer, eventBufferTimestamp)
Expand Down
55 changes: 29 additions & 26 deletions daemon/logger/awslogs/cloudwatchlogs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,10 @@ func TestCreateAlreadyExists(t *testing.T) {
func TestLogClosed(t *testing.T) {
mockClient := &mockClient{}
stream := &logStream{
client: mockClient,
closed: true,
client: mockClient,
messages: loggerutils.NewMessageQueue(0),
}
stream.Close()
err := stream.Log(&logger.Message{})
assert.Check(t, err != nil)
}
Expand All @@ -370,7 +371,7 @@ func TestLogBlocking(t *testing.T) {
mockClient := &mockClient{}
stream := &logStream{
client: mockClient,
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}

errorCh := make(chan error, 1)
Expand All @@ -387,14 +388,11 @@ func TestLogBlocking(t *testing.T) {
t.Fatal("Expected stream.Log to block: ", err)
default:
}

// assuming it is blocked, we can now try to drain the internal channel and
// unblock it
select {
case <-time.After(10 * time.Millisecond):
// if we're unable to drain the channel within 10ms, something seems broken
t.Fatal("Expected to be able to read from stream.messages but was unable to")
case <-stream.messages:
}
<-stream.messages.Receiver()

select {
case err := <-errorCh:
assert.NilError(t, err)
Expand All @@ -408,7 +406,7 @@ func TestLogBufferEmpty(t *testing.T) {
mockClient := &mockClient{}
stream := &logStream{
client: mockClient,
messages: make(chan *logger.Message, 1),
messages: loggerutils.NewMessageQueue(1),
}
err := stream.Log(&logger.Message{})
assert.NilError(t, err)
Expand Down Expand Up @@ -556,7 +554,7 @@ func TestCollectBatchSimple(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
mockClient.putLogEventsFunc = func(ctx context.Context, input *cloudwatchlogs.PutLogEventsInput, opts ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.PutLogEventsOutput, error) {
Expand All @@ -575,15 +573,20 @@ func TestCollectBatchSimple(t *testing.T) {
close(d)
go stream.collectBatch(d)

stream.Log(&logger.Message{
err := stream.Log(&logger.Message{
Line: []byte(logline),
Timestamp: time.Time{},
})
assert.NilError(t, err)

ticks <- time.Time{}
ticks <- time.Time{}
stream.Close()

for len(calls) != 1 {
time.Sleep(10 * time.Millisecond)
}

assert.Assert(t, len(calls) == 1)
argument := calls[0]
assert.Assert(t, argument != nil)
Expand All @@ -598,7 +601,7 @@ func TestCollectBatchTicker(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -666,7 +669,7 @@ func TestCollectBatchMultilinePattern(t *testing.T) {
logStreamName: streamName,
multilinePattern: multilinePattern,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -732,7 +735,7 @@ func BenchmarkCollectBatch(b *testing.B) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
mockClient.putLogEventsFunc = func(ctx context.Context, input *cloudwatchlogs.PutLogEventsInput, opts ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.PutLogEventsOutput, error) {
return &cloudwatchlogs.PutLogEventsOutput{
Expand Down Expand Up @@ -765,7 +768,7 @@ func BenchmarkCollectBatchMultilinePattern(b *testing.B) {
logStreamName: streamName,
multilinePattern: multilinePattern,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
mockClient.putLogEventsFunc = func(ctx context.Context, input *cloudwatchlogs.PutLogEventsInput, opts ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.PutLogEventsOutput, error) {
return &cloudwatchlogs.PutLogEventsOutput{
Expand Down Expand Up @@ -796,7 +799,7 @@ func TestCollectBatchMultilinePatternMaxEventAge(t *testing.T) {
logStreamName: streamName,
multilinePattern: multilinePattern,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -870,7 +873,7 @@ func TestCollectBatchMultilinePatternNegativeEventAge(t *testing.T) {
logStreamName: streamName,
multilinePattern: multilinePattern,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -927,7 +930,7 @@ func TestCollectBatchMultilinePatternMaxEventSize(t *testing.T) {
logStreamName: streamName,
multilinePattern: multilinePattern,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -987,7 +990,7 @@ func TestCollectBatchClose(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -1089,7 +1092,7 @@ func TestCollectBatchLineSplit(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -1137,7 +1140,7 @@ func TestCollectBatchLineSplitWithBinary(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -1185,7 +1188,7 @@ func TestCollectBatchMaxEvents(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -1239,7 +1242,7 @@ func TestCollectBatchMaxTotalBytes(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -1320,7 +1323,7 @@ func TestCollectBatchMaxTotalBytesWithBinary(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down Expand Up @@ -1394,7 +1397,7 @@ func TestCollectBatchWithDuplicateTimestamps(t *testing.T) {
logGroupName: groupName,
logStreamName: streamName,
sequenceToken: aws.String(sequenceToken),
messages: make(chan *logger.Message),
messages: loggerutils.NewMessageQueue(0),
}
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
called := make(chan struct{}, 50)
Expand Down