Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ protected override async Task<int> InternalRunAsync(CancellationToken cancellati
}
}

AsynchronousMessageBus concreteMessageBusService = new(
var concreteMessageBusService = new AsynchronousMessageBus(
[.. dataConsumersBuilder],
ServiceProvider.GetTestApplicationCancellationTokenSource(),
ServiceProvider.GetTask(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ namespace Microsoft.Testing.Platform.Messages;
[DebuggerDisplay("DataConsumer = {DataConsumer.Uid}")]
internal sealed class AsyncConsumerDataProcessor : IAsyncConsumerDataProcessor
{
private readonly ITask _task;
private readonly CancellationToken _cancellationToken;
private readonly Channel<(IDataProducer DataProducer, IData Data)> _channel = Channel.CreateUnbounded<(IDataProducer DataProducer, IData Data)>(new UnboundedChannelOptions

private readonly Channel<AsyncConsumerDataProcessorMessage> _channel = Channel.CreateUnbounded<AsyncConsumerDataProcessorMessage>(new UnboundedChannelOptions
{
// We process only 1 data at a time
SingleReader = true,
Expand All @@ -27,27 +27,27 @@ internal sealed class AsyncConsumerDataProcessor : IAsyncConsumerDataProcessor
AllowSynchronousContinuations = false,
});

// This is needed to avoid possible race condition between drain and _totalPayloadProcessed race condition.
// This is the "logical" consume workflow state.
private readonly TaskCompletionSource _consumerState = new();
private readonly Task _consumeTask;
private long _totalPayloadReceived;
private long _totalPayloadProcessed;

// Number of data payloads enqueued via PublishAsync. The message bus reads this via
// ReceivedCount to detect publisher/consumer cycles across drain rounds.
private long _receivedCount;

public AsyncConsumerDataProcessor(IDataConsumer consumer, ITask task, CancellationToken cancellationToken)
{
DataConsumer = consumer;
_task = task;
_cancellationToken = cancellationToken;
_consumeTask = task.Run(ConsumeAsync, cancellationToken);
}

public IDataConsumer DataConsumer { get; }

public long ReceivedCount => Volatile.Read(ref _receivedCount);

public async Task PublishAsync(IDataProducer dataProducer, IData data)
{
Interlocked.Increment(ref _totalPayloadReceived);
await _channel.Writer.WriteAsync((dataProducer, data), _cancellationToken).ConfigureAwait(false);
Interlocked.Increment(ref _receivedCount);
await _channel.Writer.WriteAsync(AsyncConsumerDataProcessorMessage.CreateData(dataProducer, data), _cancellationToken).ConfigureAwait(false);
}

private async Task ConsumeAsync()
Expand All @@ -56,59 +56,30 @@ private async Task ConsumeAsync()
{
while (await _channel.Reader.WaitToReadAsync(_cancellationToken).ConfigureAwait(false))
{
(IDataProducer dataProducer, IData data) = await _channel.Reader.ReadAsync(_cancellationToken).ConfigureAwait(false);
AsyncConsumerDataProcessorMessage message = await _channel.Reader.ReadAsync(_cancellationToken).ConfigureAwait(false);

try
if (message.DrainMarker is { } drainMarker)
{
// We don't enqueue the data if the consumer is the producer of the data.
// We could optimize this if and make a get with type/all but producers, but it
// could be over-engineering.
if (dataProducer.Uid == DataConsumer.Uid)
{
continue;
}

try
{
await DataConsumer.ConsumeAsync(dataProducer, data, _cancellationToken).ConfigureAwait(false);
}

// We let the catch below to handle the graceful cancellation of the process
catch (Exception ex) when (ex is not OperationCanceledException)
{
// If we're draining before to increment the _totalPayloadProcessed we need to signal that we should throw because
// it's possible we have a race condition where the payload counting in DrainDataAsync returns false and the current task is not yet in a
// "faulted state".
_consumerState.SetException(ex);

// We let current task to move to fault state, checked inside CompleteAddingAsync.
throw;
}
// The drain marker passed all the data items previously enqueued so we can signal the drain caller.
drainMarker.TrySetResult(true);
continue;
}
finally

// We don't enqueue the data if the consumer is the producer of the data.
// We could optimize this if and make a get with type/all but producers, but it
// could be over-engineering.
if (message.DataProducer!.Uid == DataConsumer.Uid)
{
Interlocked.Increment(ref _totalPayloadProcessed);
continue;
}

await DataConsumer.ConsumeAsync(message.DataProducer, message.Data!, _cancellationToken).ConfigureAwait(false);
}
}
catch (OperationCanceledException oc) when (oc.CancellationToken == _cancellationToken)
{
// Ignore we're shutting down
}
catch (Exception ex)
{
// For all other exception we signal the state if not already faulted
if (!_consumerState.Task.IsFaulted)
{
_consumerState.SetException(ex);
}

// let the exception bubble up
throw;
}

// We're exiting gracefully, signal the correct state.
_consumerState.SetResult();
}

public async Task CompleteAddingAsync()
Expand All @@ -121,43 +92,37 @@ public async Task CompleteAddingAsync()
await _consumeTask.ConfigureAwait(false);
}

public async Task<long> DrainDataAsync()
public async Task DrainDataAsync()
{
// We go volatile because we race with Interlocked.Increment in PublishAsync
long totalPayloadProcessed = Volatile.Read(ref _totalPayloadProcessed);
long totalPayloadReceived = Volatile.Read(ref _totalPayloadReceived);
const int minDelayTimeMs = 25;
int currentDelayTimeMs = minDelayTimeMs;
while (Interlocked.CompareExchange(ref _totalPayloadReceived, totalPayloadReceived, totalPayloadProcessed) != totalPayloadProcessed)
{
// When we cancel we throw inside ConsumeAsync and we won't drain anymore any data
if (_cancellationToken.IsCancellationRequested)
{
break;
}
var drainMarker = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

await _task.Delay(currentDelayTimeMs).ConfigureAwait(false);
currentDelayTimeMs = Math.Min(currentDelayTimeMs + minDelayTimeMs, 200);

if (_consumerState.Task.IsFaulted)
{
// Rethrow the exception
await _consumerState.Task.ConfigureAwait(false);
}

// Wait for the consumer to complete the current enqueued items
totalPayloadProcessed = Volatile.Read(ref _totalPayloadProcessed);
totalPayloadReceived = Volatile.Read(ref _totalPayloadReceived);
try
{
await _channel.Writer.WriteAsync(AsyncConsumerDataProcessorMessage.CreateDrainMarker(drainMarker), _cancellationToken).ConfigureAwait(false);
}

// It' possible that we fail and we have consumed the item
if (_consumerState.Task.IsFaulted)
catch (ChannelClosedException)
{
// Rethrow the exception
await _consumerState.Task.ConfigureAwait(false);
// The channel was already completed (e.g., by DisableAsync). Nothing left to drain.
return;
}
catch (OperationCanceledException oc) when (oc.CancellationToken == _cancellationToken)
{
// The application is shutting down. Treat the drain as a graceful no-op,
// matching the previous behavior of bailing out of DrainDataAsync on cancellation.
return;
}

return _totalPayloadReceived;
// Wait either for the drain marker to be dequeued, or for the consume task to finish/fault.
// If the consume task ends before the marker is reached, propagate any failure it surfaced.
Task completed = await Task.WhenAny(drainMarker.Task, _consumeTask).ConfigureAwait(false);
if (completed == _consumeTask)
{
await _consumeTask.ConfigureAwait(false);
}
else
{
await drainMarker.Task.ConfigureAwait(false);
}
}

// At this point we simply signal the channel as complete and we don't wait for the consumer to complete.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,31 @@ namespace Microsoft.Testing.Platform.Messages;

internal sealed class AsyncConsumerDataProcessor : IAsyncConsumerDataProcessor
{
private readonly ITask _task;
private readonly CancellationToken _cancellationToken;
private readonly SingleConsumerUnboundedChannel<(IDataProducer DataProducer, IData Data)> _channel = new();

// This is needed to avoid possible race condition between drain and _totalPayloadProcessed race condition.
// This is the "logical" consume workflow state.
private readonly TaskCompletionSource<object> _consumerState = new();
private readonly SingleConsumerUnboundedChannel<AsyncConsumerDataProcessorMessage> _channel = new();
private readonly Task _consumeTask;
private long _totalPayloadReceived;
private long _totalPayloadProcessed;

// Number of data payloads enqueued via PublishAsync. The message bus reads this via
// ReceivedCount to detect publisher/consumer cycles across drain rounds.
private long _receivedCount;

public AsyncConsumerDataProcessor(IDataConsumer dataConsumer, ITask task, CancellationToken cancellationToken)
{
DataConsumer = dataConsumer;
_task = task;
_cancellationToken = cancellationToken;
_consumeTask = task.Run(ConsumeAsync, cancellationToken);
}

public IDataConsumer DataConsumer { get; }

public long ReceivedCount => Volatile.Read(ref _receivedCount);

public Task PublishAsync(IDataProducer dataProducer, IData data)
{
_cancellationToken.ThrowIfCancellationRequested();
Interlocked.Increment(ref _totalPayloadReceived);
_channel.Write((dataProducer, data));
Interlocked.Increment(ref _receivedCount);
_channel.Write(AsyncConsumerDataProcessorMessage.CreateData(dataProducer, data));
return Task.CompletedTask;
}

Expand All @@ -45,60 +44,31 @@ private async Task ConsumeAsync()
{
while (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false))
{
while (_channel.TryRead(out (IDataProducer DataProducer, IData Data) item))
while (_channel.TryRead(out AsyncConsumerDataProcessorMessage message))
{
try
if (message.DrainMarker is { } drainMarker)
{
// We don't enqueue the data if the consumer is the producer of the data.
// We could optimize this if and make a get with type/all but producers, but it
// could be over-engineering.
if (item.DataProducer.Uid == DataConsumer.Uid)
{
continue;
}

try
{
await DataConsumer.ConsumeAsync(item.DataProducer, item.Data, _cancellationToken).ConfigureAwait(false);
}

// We let the catch below to handle the graceful cancellation of the process
catch (Exception ex) when (ex is not OperationCanceledException)
{
// If we're draining before to increment the _totalPayloadProcessed we need to signal that we should throw because
// it's possible we have a race condition where the payload check at line 106 return false and the current task is not yet in a
// "faulted state".
_consumerState.SetException(ex);

// We let current task to move to fault state, checked inside CompleteAddingAsync.
throw;
}
// The drain marker passed all the data items previously enqueued so we can signal the drain caller.
drainMarker.TrySetResult(true);
continue;
}
finally

// We don't enqueue the data if the consumer is the producer of the data.
// We could optimize this if and make a get with type/all but producers, but it
// could be over-engineering.
if (message.DataProducer!.Uid == DataConsumer.Uid)
{
Interlocked.Increment(ref _totalPayloadProcessed);
continue;
}

await DataConsumer.ConsumeAsync(message.DataProducer, message.Data!, _cancellationToken).ConfigureAwait(false);
}
}
}
catch (OperationCanceledException oc) when (oc.CancellationToken == _cancellationToken)
{
// Ignore we're shutting down
}
catch (Exception ex)
{
// For all other exception we signal the state if not already faulted
if (!_consumerState.Task.IsFaulted)
{
_consumerState.SetException(ex);
}

// let the exception bubble up
throw;
}

// We're exiting gracefully, signal the correct state.
_consumerState.SetResult(new object());
}

public async Task CompleteAddingAsync()
Expand All @@ -111,43 +81,37 @@ public async Task CompleteAddingAsync()
await _consumeTask.ConfigureAwait(false);
}

public async Task<long> DrainDataAsync()
public async Task DrainDataAsync()
{
// We go volatile because we race with Interlocked.Increment in PublishAsync
long totalPayloadProcessed = Volatile.Read(ref _totalPayloadProcessed);
long totalPayloadReceived = Volatile.Read(ref _totalPayloadReceived);
const int minDelayTimeMs = 25;
int currentDelayTimeMs = minDelayTimeMs;
while (Interlocked.CompareExchange(ref _totalPayloadReceived, totalPayloadReceived, totalPayloadProcessed) != totalPayloadProcessed)
{
// When we cancel we throw inside ConsumeAsync and we won't drain anymore any data
if (_cancellationToken.IsCancellationRequested)
{
break;
}

await _task.Delay(currentDelayTimeMs).ConfigureAwait(false);
currentDelayTimeMs = Math.Min(currentDelayTimeMs + minDelayTimeMs, 200);
var drainMarker = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

if (_consumerState.Task.IsFaulted)
{
// Rethrow the exception
await _consumerState.Task.ConfigureAwait(false);
}

// Wait for the consumer to complete the current enqueued items
totalPayloadProcessed = Volatile.Read(ref _totalPayloadProcessed);
totalPayloadReceived = Volatile.Read(ref _totalPayloadReceived);
try
{
_channel.Write(AsyncConsumerDataProcessorMessage.CreateDrainMarker(drainMarker));
}

// It' possible that we fail and we have consumed the item
if (_consumerState.Task.IsFaulted)
catch (InvalidOperationException)
{
// Rethrow the exception
await _consumerState.Task.ConfigureAwait(false);
// The channel was already completed (e.g., by DisableAsync). Nothing left to drain.
return;
}
catch (OperationCanceledException oc) when (oc.CancellationToken == _cancellationToken)
{
// The application is shutting down. Treat the drain as a graceful no-op,
// matching the previous behavior of bailing out of DrainDataAsync on cancellation.
return;
}

return _totalPayloadReceived;
// Wait either for the drain marker to be dequeued, or for the consume task to finish/fault.
// If the consume task ends before the marker is reached, propagate any failure it surfaced.
Task completed = await Task.WhenAny(drainMarker.Task, _consumeTask).ConfigureAwait(false);
if (completed == _consumeTask)
{
await _consumeTask.ConfigureAwait(false);
}
else
{
await drainMarker.Task.ConfigureAwait(false);
}
}

public void Dispose()
Expand Down
Loading
Loading