Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ internal static class EnvironmentVariableConstants
public const string DOTNET_WATCH = nameof(DOTNET_WATCH);
public const string TESTINGPLATFORM_HOTRELOAD_ENABLED = nameof(TESTINGPLATFORM_HOTRELOAD_ENABLED);
public const string TESTINGPLATFORM_DEFAULT_HANG_TIMEOUT = nameof(TESTINGPLATFORM_DEFAULT_HANG_TIMEOUT);
public const string TESTINGPLATFORM_MESSAGEBUS_DRAINDATA_ATTEMPTS = nameof(TESTINGPLATFORM_MESSAGEBUS_DRAINDATA_ATTEMPTS);

public const string TESTINGPLATFORM_TESTHOSTCONTROLLER_SKIPEXTENSION = nameof(TESTINGPLATFORM_TESTHOSTCONTROLLER_SKIPEXTENSION);
public const string TESTINGPLATFORM_TESTHOSTCONTROLLER_PIPENAME = nameof(TESTINGPLATFORM_TESTHOSTCONTROLLER_PIPENAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ private static async Task<ITestFramework> BuildTestFrameworkAsync(TestFrameworkB
[.. dataConsumersBuilder],
serviceProvider.GetTestApplicationCancellationTokenSource(),
serviceProvider.GetTask(),
serviceProvider.GetLoggerFactory(),
serviceProvider.GetEnvironment());
serviceProvider.GetLoggerFactory());
await concreteMessageBusService.InitAsync().ConfigureAwait(false);
testFrameworkBuilderData.MessageBusProxy.SetBuiltMessageBus(concreteMessageBusService);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,11 @@ protected override async Task<int> InternalRunAsync(CancellationToken cancellati
}
}

AsynchronousMessageBus concreteMessageBusService = new(
var concreteMessageBusService = new AsynchronousMessageBus(
[.. dataConsumersBuilder],
ServiceProvider.GetTestApplicationCancellationTokenSource(),
ServiceProvider.GetTask(),
ServiceProvider.GetLoggerFactory(),
ServiceProvider.GetEnvironment());
ServiceProvider.GetLoggerFactory());
await concreteMessageBusService.InitAsync().ConfigureAwait(false);
((MessageBusProxy)ServiceProvider.GetMessageBus()).SetBuiltMessageBus(concreteMessageBusService);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ namespace Microsoft.Testing.Platform.Messages;
[DebuggerDisplay("DataConsumer = {DataConsumer.Uid}")]
internal sealed class AsyncConsumerDataProcessor : IAsyncConsumerDataProcessor
{
private readonly ITask _task;
private readonly CancellationToken _cancellationToken;
private readonly Channel<(IDataProducer DataProducer, IData Data)> _channel = Channel.CreateUnbounded<(IDataProducer DataProducer, IData Data)>(new UnboundedChannelOptions

private readonly Channel<AsyncConsumerDataProcessorMessage> _channel = Channel.CreateUnbounded<AsyncConsumerDataProcessorMessage>(new UnboundedChannelOptions
{
// We process only 1 data at a time
SingleReader = true,
Expand All @@ -27,88 +27,56 @@ internal sealed class AsyncConsumerDataProcessor : IAsyncConsumerDataProcessor
AllowSynchronousContinuations = false,
});

// This is needed to avoid possible race condition between drain and _totalPayloadProcessed race condition.
// This is the "logical" consume workflow state.
private readonly TaskCompletionSource _consumerState = new();
private readonly Task _consumeTask;
private long _totalPayloadReceived;
private long _totalPayloadProcessed;

// Number of data payloads dequeued by the consumer. Used by DrainDataAsync to detect publisher/consumer loops.
// Only the single consumer task increments this field; other threads read it via Volatile.Read.
private long _processedCount;

public AsyncConsumerDataProcessor(IDataConsumer consumer, ITask task, CancellationToken cancellationToken)
{
DataConsumer = consumer;
_task = task;
_cancellationToken = cancellationToken;
_consumeTask = task.Run(ConsumeAsync, cancellationToken);
}

public IDataConsumer DataConsumer { get; }

public async Task PublishAsync(IDataProducer dataProducer, IData data)
{
Interlocked.Increment(ref _totalPayloadReceived);
await _channel.Writer.WriteAsync((dataProducer, data), _cancellationToken).ConfigureAwait(false);
}
=> await _channel.Writer.WriteAsync(AsyncConsumerDataProcessorMessage.CreateData(dataProducer, data), _cancellationToken).ConfigureAwait(false);

private async Task ConsumeAsync()
{
try
{
while (await _channel.Reader.WaitToReadAsync(_cancellationToken).ConfigureAwait(false))
{
(IDataProducer dataProducer, IData data) = await _channel.Reader.ReadAsync(_cancellationToken).ConfigureAwait(false);
AsyncConsumerDataProcessorMessage message = await _channel.Reader.ReadAsync(_cancellationToken).ConfigureAwait(false);

try
if (message.DrainMarker is { } drainMarker)
{
// We don't enqueue the data if the consumer is the producer of the data.
// We could optimize this if and make a get with type/all but producers, but it
// could be over-engineering.
if (dataProducer.Uid == DataConsumer.Uid)
{
continue;
}

try
{
await DataConsumer.ConsumeAsync(dataProducer, data, _cancellationToken).ConfigureAwait(false);
}

// We let the catch below to handle the graceful cancellation of the process
catch (Exception ex) when (ex is not OperationCanceledException)
{
// If we're draining before to increment the _totalPayloadProcessed we need to signal that we should throw because
// it's possible we have a race condition where the payload counting in DrainDataAsync returns false and the current task is not yet in a
// "faulted state".
_consumerState.SetException(ex);

// We let current task to move to fault state, checked inside CompleteAddingAsync.
throw;
}
// The drain marker passed all the data items previously enqueued so we can signal the drain caller.
drainMarker.TrySetResult(true);
continue;
}
finally

Interlocked.Increment(ref _processedCount);

// We don't enqueue the data if the consumer is the producer of the data.
// We could optimize this if and make a get with type/all but producers, but it
// could be over-engineering.
if (message.DataProducer!.Uid == DataConsumer.Uid)
{
Interlocked.Increment(ref _totalPayloadProcessed);
continue;
}

await DataConsumer.ConsumeAsync(message.DataProducer, message.Data!, _cancellationToken).ConfigureAwait(false);
}
}
catch (OperationCanceledException oc) when (oc.CancellationToken == _cancellationToken)
{
// Ignore we're shutting down
}
catch (Exception ex)
{
// For all other exception we signal the state if not already faulted
if (!_consumerState.Task.IsFaulted)
{
_consumerState.SetException(ex);
}

// let the exception bubble up
throw;
}

// We're exiting gracefully, signal the correct state.
_consumerState.SetResult();
}

public async Task CompleteAddingAsync()
Expand All @@ -121,43 +89,34 @@ public async Task CompleteAddingAsync()
await _consumeTask.ConfigureAwait(false);
}

public async Task<long> DrainDataAsync()
public async Task<bool> DrainDataAsync()
{
// We go volatile because we race with Interlocked.Increment in PublishAsync
long totalPayloadProcessed = Volatile.Read(ref _totalPayloadProcessed);
long totalPayloadReceived = Volatile.Read(ref _totalPayloadReceived);
const int minDelayTimeMs = 25;
int currentDelayTimeMs = minDelayTimeMs;
while (Interlocked.CompareExchange(ref _totalPayloadReceived, totalPayloadReceived, totalPayloadProcessed) != totalPayloadProcessed)
{
// When we cancel we throw inside ConsumeAsync and we won't drain anymore any data
if (_cancellationToken.IsCancellationRequested)
{
break;
}

await _task.Delay(currentDelayTimeMs).ConfigureAwait(false);
currentDelayTimeMs = Math.Min(currentDelayTimeMs + minDelayTimeMs, 200);

if (_consumerState.Task.IsFaulted)
{
// Rethrow the exception
await _consumerState.Task.ConfigureAwait(false);
}
long before = Volatile.Read(ref _processedCount);
var drainMarker = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

// Wait for the consumer to complete the current enqueued items
totalPayloadProcessed = Volatile.Read(ref _totalPayloadProcessed);
totalPayloadReceived = Volatile.Read(ref _totalPayloadReceived);
try
{
await _channel.Writer.WriteAsync(AsyncConsumerDataProcessorMessage.CreateDrainMarker(drainMarker), _cancellationToken).ConfigureAwait(false);
}
catch (ChannelClosedException)
{
// The channel was already completed (e.g., by DisableAsync). Nothing left to drain.
return false;
}

// It' possible that we fail and we have consumed the item
if (_consumerState.Task.IsFaulted)
// Wait either for the drain marker to be dequeued, or for the consume task to finish/fault.
// If the consume task ends before the marker is reached, propagate any failure it surfaced.
Task completed = await Task.WhenAny(drainMarker.Task, _consumeTask).ConfigureAwait(false);
if (completed == _consumeTask)
{
await _consumeTask.ConfigureAwait(false);
}
else
{
// Rethrow the exception
await _consumerState.Task.ConfigureAwait(false);
await drainMarker.Task.ConfigureAwait(false);
}

return _totalPayloadReceived;
return Volatile.Read(ref _processedCount) != before;
}
Comment on lines +95 to 126
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical race condition: When DrainDataAsync completes the channel writer (line 72) and before creating a new channel (line 75), any concurrent calls to PublishAsync will throw ChannelClosedException when trying to write to the completed channel. This is problematic because DrainDataAsync is called at multiple synchronization points during normal execution (see CommonTestHost.cs lines 223, 229, 245, 249), not just during shutdown. The old implementation avoided this by not completing the channel during drain. Consider using a lock or other synchronization mechanism to atomically swap the old channel with a new one, or ensure no publishing can occur during drain.

Copilot uses AI. Check for mistakes.

// At this point we simply signal the channel as complete and we don't wait for the consumer to complete.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,18 @@ namespace Microsoft.Testing.Platform.Messages;

internal sealed class AsyncConsumerDataProcessor : IAsyncConsumerDataProcessor
{
private readonly ITask _task;
private readonly CancellationToken _cancellationToken;
private readonly SingleConsumerUnboundedChannel<(IDataProducer DataProducer, IData Data)> _channel = new();

// This is needed to avoid possible race condition between drain and _totalPayloadProcessed race condition.
// This is the "logical" consume workflow state.
private readonly TaskCompletionSource<object> _consumerState = new();
private readonly SingleConsumerUnboundedChannel<AsyncConsumerDataProcessorMessage> _channel = new();
private readonly Task _consumeTask;
private long _totalPayloadReceived;
private long _totalPayloadProcessed;

// Number of data payloads dequeued by the consumer. Used by DrainDataAsync to detect publisher/consumer loops.
// Only the single consumer task increments this field; other threads read it via Volatile.Read.
private long _processedCount;

public AsyncConsumerDataProcessor(IDataConsumer dataConsumer, ITask task, CancellationToken cancellationToken)
{
DataConsumer = dataConsumer;
_task = task;
_cancellationToken = cancellationToken;
_consumeTask = task.Run(ConsumeAsync, cancellationToken);
}
Expand All @@ -34,8 +31,7 @@ public AsyncConsumerDataProcessor(IDataConsumer dataConsumer, ITask task, Cancel
public Task PublishAsync(IDataProducer dataProducer, IData data)
{
_cancellationToken.ThrowIfCancellationRequested();
Interlocked.Increment(ref _totalPayloadReceived);
_channel.Write((dataProducer, data));
_channel.Write(AsyncConsumerDataProcessorMessage.CreateData(dataProducer, data));
return Task.CompletedTask;
}

Expand All @@ -45,60 +41,33 @@ private async Task ConsumeAsync()
{
while (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false))
{
while (_channel.TryRead(out (IDataProducer DataProducer, IData Data) item))
while (_channel.TryRead(out AsyncConsumerDataProcessorMessage message))
{
try
if (message.DrainMarker is { } drainMarker)
{
// We don't enqueue the data if the consumer is the producer of the data.
// We could optimize this if and make a get with type/all but producers, but it
// could be over-engineering.
if (item.DataProducer.Uid == DataConsumer.Uid)
{
continue;
}

try
{
await DataConsumer.ConsumeAsync(item.DataProducer, item.Data, _cancellationToken).ConfigureAwait(false);
}

// We let the catch below to handle the graceful cancellation of the process
catch (Exception ex) when (ex is not OperationCanceledException)
{
// If we're draining before to increment the _totalPayloadProcessed we need to signal that we should throw because
// it's possible we have a race condition where the payload check at line 106 return false and the current task is not yet in a
// "faulted state".
_consumerState.SetException(ex);

// We let current task to move to fault state, checked inside CompleteAddingAsync.
throw;
}
// The drain marker passed all the data items previously enqueued so we can signal the drain caller.
drainMarker.TrySetResult(true);
continue;
}
finally

Interlocked.Increment(ref _processedCount);

// We don't enqueue the data if the consumer is the producer of the data.
// We could optimize this if and make a get with type/all but producers, but it
// could be over-engineering.
if (message.DataProducer!.Uid == DataConsumer.Uid)
{
Interlocked.Increment(ref _totalPayloadProcessed);
continue;
}

await DataConsumer.ConsumeAsync(message.DataProducer, message.Data!, _cancellationToken).ConfigureAwait(false);
}
}
}
catch (OperationCanceledException oc) when (oc.CancellationToken == _cancellationToken)
{
// Ignore we're shutting down
}
catch (Exception ex)
{
// For all other exception we signal the state if not already faulted
if (!_consumerState.Task.IsFaulted)
{
_consumerState.SetException(ex);
}

// let the exception bubble up
throw;
}

// We're exiting gracefully, signal the correct state.
_consumerState.SetResult(new object());
}

public async Task CompleteAddingAsync()
Expand All @@ -111,43 +80,34 @@ public async Task CompleteAddingAsync()
await _consumeTask.ConfigureAwait(false);
}

public async Task<long> DrainDataAsync()
public async Task<bool> DrainDataAsync()
{
// We go volatile because we race with Interlocked.Increment in PublishAsync
long totalPayloadProcessed = Volatile.Read(ref _totalPayloadProcessed);
long totalPayloadReceived = Volatile.Read(ref _totalPayloadReceived);
const int minDelayTimeMs = 25;
int currentDelayTimeMs = minDelayTimeMs;
while (Interlocked.CompareExchange(ref _totalPayloadReceived, totalPayloadReceived, totalPayloadProcessed) != totalPayloadProcessed)
{
// When we cancel we throw inside ConsumeAsync and we won't drain anymore any data
if (_cancellationToken.IsCancellationRequested)
{
break;
}

await _task.Delay(currentDelayTimeMs).ConfigureAwait(false);
currentDelayTimeMs = Math.Min(currentDelayTimeMs + minDelayTimeMs, 200);

if (_consumerState.Task.IsFaulted)
{
// Rethrow the exception
await _consumerState.Task.ConfigureAwait(false);
}
long before = Volatile.Read(ref _processedCount);
var drainMarker = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

// Wait for the consumer to complete the current enqueued items
totalPayloadProcessed = Volatile.Read(ref _totalPayloadProcessed);
totalPayloadReceived = Volatile.Read(ref _totalPayloadReceived);
try
{
_channel.Write(AsyncConsumerDataProcessorMessage.CreateDrainMarker(drainMarker));
}
catch (InvalidOperationException)
{
// The channel was already completed (e.g., by DisableAsync). Nothing left to drain.
return false;
}

// It' possible that we fail and we have consumed the item
if (_consumerState.Task.IsFaulted)
// Wait either for the drain marker to be dequeued, or for the consume task to finish/fault.
// If the consume task ends before the marker is reached, propagate any failure it surfaced.
Task completed = await Task.WhenAny(drainMarker.Task, _consumeTask).ConfigureAwait(false);
if (completed == _consumeTask)
{
await _consumeTask.ConfigureAwait(false);
}
else
{
// Rethrow the exception
await _consumerState.Task.ConfigureAwait(false);
await drainMarker.Task.ConfigureAwait(false);
}

return _totalPayloadReceived;
return Volatile.Read(ref _processedCount) != before;
}
Comment on lines +84 to 115
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical race condition: When DrainDataAsync completes the channel (line 74) and before creating a new channel (line 77), any concurrent calls to PublishAsync will throw InvalidOperationException when trying to write to the completed channel. This is problematic because DrainDataAsync is called at multiple synchronization points during normal execution (see CommonTestHost.cs lines 223, 229, 245, 249), not just during shutdown. The old implementation avoided this by not completing the channel during drain. Consider using a lock or other synchronization mechanism to atomically swap the old channel with a new one, or ensure no publishing can occur during drain.

Copilot uses AI. Check for mistakes.

public void Dispose()
Expand Down
Loading
Loading