在PLINQ中绑定源线程

wkyowqbh  于 2023-01-22  发布在  其他
关注(0)|答案(2)|浏览(138)

我有一个使用PLINQ并行化的计算,如下所示:

  • IEnumerable<T> source正在提供从文件读取的对象。
  • 我需要在每个T上执行一个重量级计算HeavyComputation,我希望这些计算跨线程进行,因此我使用PLINQ,如下所示:AsParallel().Select(HeavyComputation)

有趣的是:由于对提供source的文件读取器类型的限制,我需要在初始线程上枚举source,而不是在并行工作线程上。我需要将source的完整求值 * 绑定 * 到主线程。然而,看起来源代码实际上是在工作线程上枚举的。
我的问题是:是否有一种直接的方法可以修改此代码,将source的枚举绑定到初始线程,同时将繁重的工作分配给并行工作者?请记住,在AsParallel()之前执行急切的.ToList()不是这里的选项,因为来自文件的数据流非常庞大。
下面是一些示例代码,演示了我所看到的问题:

using System.Threading;
using System.Collections.Generic;
using System.Linq;
using System;

public class PlinqTest
{
    private static string FormatItems<T>(IEnumerable<T> source)
    {
            return String.Format("[{0}]", String.Join(";", source));
    }

    public static void Main()
    {
        var expectedThreadIds = new[] { Thread.CurrentThread.ManagedThreadId };

        var threadIds = Enumerable.Range(1, 1000)
                .Select(x => Thread.CurrentThread.ManagedThreadId) // (1)
                .AsParallel()
                .WithDegreeOfParallelism(8)
                .WithExecutionMode(ParallelExecutionMode.ForceParallelism)
                .AsOrdered()
                .Select(x => x)                                    // (2)
                .ToArray();

        // In the computation above, the lambda in (1) is a
        // stand in for the file-reading operation that we
        // want to be bound to the main thread, while the
        // lambda in (2) is a stand-in for the "expensive
        // computation" that we want to be farmed out to the
        // parallel worker threads.  In fact, (1) is being
        // executed on all threads, as can be seen from the
        // output.

        Console.WriteLine("Expected thread IDs: {0}",
                          FormatItems(expectedThreadIds));
        Console.WriteLine("Found thread IDs: {0}",
                          FormatItems(threadIds.Distinct()));
    }
}

我得到的示例输出是:

Expected thread IDs: [1]
Found thread IDs: [7;4;8;6;11;5;10;9]
3b6akqbq

3b6akqbq1#

如果放弃PLINQ而仅显式使用Task Parallel Library,这将非常简单(尽管可能不那么简洁):

// Limits the parallelism of the "expensive task"
var semaphore = new SemaphoreSlim(8);

var tasks = Enumerable.Range(1, 1000)
    .Select(x => Thread.CurrentThread.ManagedThreadId)
    .Select(async x =>
    {
        await semaphore.WaitAsync();
        var result = await Task.Run(() => Tuple.Create(x, Thread.CurrentThread.ManagedThreadId));
        semaphore.Release();

        return result;
    });

return Task.WhenAll(tasks).Result;

注意,我使用Tuple.Create来记录来自主线程的线程ID和来自派生任务的线程ID,从我的测试来看,前者对于每个元组总是相同的,而后者是变化的,这是应该的。
信号量确保并行度永远不会超过8(尽管创建元组的任务成本很低,但这不太可能)。如果达到8,任何新任务都将等待,直到信号量上有可用的位置。

omtl5h9j

omtl5h9j2#

您可以使用下面的OffloadQueryEnumeration方法,该方法可确保源序列的枚举将在枚举结果IEnumerable<TResult>的同一线程上进行。querySelector是一个委托,它将源序列的代理转换为ParallelQuery<T>。此查询在ThreadPool线程上内部枚举,但输出值将返回到当前线程。

/// <summary>
/// Enumerates the source sequence on the current thread, and enumerates
/// the projected query on a ThreadPool thread.
/// </summary>
public static IEnumerable<TResult> OffloadQueryEnumeration<TSource, TResult>(
    this IEnumerable<TSource> source,
    Func<IEnumerable<TSource>, IEnumerable<TResult>> querySelector)
{
    ArgumentNullException.ThrowIfNull(source);
    ArgumentNullException.ThrowIfNull(querySelector);
    object locker = new();
    (TSource Value, bool HasValue) input = default; bool inputCompleted = false;
    (TResult Value, bool HasValue) output = default; bool outputCompleted = false;
    using IEnumerator<TSource> sourceEnumerator = source.GetEnumerator();

    IEnumerable<TSource> GetSourceProxy()
    {
        while (true)
        {
            TSource sourceItem;
            lock (locker)
            {
                while (true)
                {
                    if (inputCompleted || outputCompleted) yield break;
                    if (input.HasValue) break;
                    Monitor.Wait(locker);
                }
                sourceItem = input.Value;
                input = default; Monitor.PulseAll(locker);
            }
            yield return sourceItem;
        }
    }

    IEnumerable<TResult> query = querySelector(GetSourceProxy());
    Task outputReaderTask = Task.Run(() =>
    {
        try
        {
            foreach (TResult result in query)
            {
                lock (locker)
                {
                    while (true)
                    {
                        if (outputCompleted) return;
                        if (!output.HasValue) break;
                        Monitor.Wait(locker);
                    }
                    output = (result, true); Monitor.PulseAll(locker);
                }
            }
        }
        finally
        {
            lock (locker) { outputCompleted = true; Monitor.PulseAll(locker); }
        }
    });

    // Main loop
    List<Exception> exceptions = new();
    while (true)
    {
        TResult resultItem;
        lock (locker)
        {
            // Inner loop
            while (true)
            {
                if (output.HasValue)
                {
                    resultItem = output.Value;
                    output = default; Monitor.PulseAll(locker);
                    goto yieldResult;
                }
                if (outputCompleted) goto exitMainLoop;
                if (!inputCompleted && !input.HasValue)
                {
                    // Fill the empty input slot, by reading the enumerator.
                    try
                    {
                        if (sourceEnumerator.MoveNext())
                            input = (sourceEnumerator.Current, true);
                        else
                            inputCompleted = true;
                    }
                    catch (Exception ex)
                    {
                        exceptions.Add(ex);
                        inputCompleted = true;
                    }
                    Monitor.PulseAll(locker); continue;
                }
                Monitor.Wait(locker);
            }
        }
    yieldResult:
        bool yieldOK = false;
        try { yield return resultItem; yieldOK = true; }
        finally
        {
            if (!yieldOK)
            {
                // The consumer stopped enumerating prematurely
                lock (locker) { outputCompleted = true; Monitor.PulseAll(locker); }
                Task.WhenAny(outputReaderTask).Wait();
            }
        }
    }
exitMainLoop:

    // Propagate possible exceptions
    try { outputReaderTask.GetAwaiter().GetResult(); }
    catch (OperationCanceledException) { throw; }
    catch (AggregateException aex) { exceptions.AddRange(aex.InnerExceptions); }

    if (exceptions.Count > 0)
        throw new AggregateException(exceptions);
}

此方法使用Monitor.Wait/Monitor.Pulse机制(教程),以便同步从一个线程到另一个线程的值传输。
用法示例:

int[] threadIds = Enumerable
    .Range(1, 1000)
    .Select(x => Thread.CurrentThread.ManagedThreadId)
    .OffloadQueryEnumeration(proxy => proxy
        .AsParallel()
        .AsOrdered()
        .WithDegreeOfParallelism(8)
        .WithExecutionMode(ParallelExecutionMode.ForceParallelism)
        .Select(x => x)
    )
    .ToArray();

Online demo.
OffloadQueryEnumeration是一个非常复杂的方法,它不停地处理三个线程:
1.既枚举source序列又使用PLINQ生成的元素的当前线程,在这两个操作之间交替。
1.枚举PLINQ生成的序列的ThreadPool线程(outputReaderTask)。
1.由PLINQ机制执行任务以从GetSourceProxy()迭代器获取下一项的工作线程。此线程并非始终相同,但在任何给定时刻,最多只有一个工作线程被分配此任务。
所以很多事情都在发生,隐藏的bug有很多机会不被发现地通过,这是一种需要编写十几个测试的API,来Assert许多可能的场景的正确性(例如源序列中的故障、PLINQ操作符中的故障、使用者中的故障、取消、放弃枚举等)。我已经手动测试了其中的一些场景,但是我还没有编写任何测试,所以使用这个方法要小心。

相关问题