Fork/Join 分治模式
一道面试题:如何用最高效的方式对 1000 万个整数排序?
如果单线程归并排序,时间复杂度是 O(n log n),但只能用一个核心。如果能把数据分成多份,用多核并行排序,最后再合并结果,时间会大大缩短。
Fork/Join 框架就是为这种分治并行场景设计的。
ForkJoinPool vs ThreadPoolExecutor
JDK 提供了两种线程池,为什么还要 ForkJoinPool?
flowchart LR
subgraph ThreadPoolExecutor
A1[Thread 1] --> Q1[共享队列]
A2[Thread 2] --> Q1
A3[Thread 3] --> Q1
end
subgraph ForkJoinPool
B1[Thread 1] --> Q2[队列1]
B2[Thread 2] --> Q3[队列2]
B3[Thread 3] --> Q4[队列3]
Q4 -.->|窃取| Q2
end
ForkJoinPool 的核心是工作窃取算法:每个线程有自己的任务队列,任务执行时会「分叉」(fork)出子任务放回队列。空闲的线程可以从其他线程队列尾部「偷」任务来执行。
工作窃取算法
为什么工作窃取比共享队列更好?
共享队列的问题:多个线程竞争同一个队列,需要加锁。在高并发下,锁竞争会成为瓶颈。
工作窃取的优势:
- 减少锁竞争:大部分操作是线程自己的队列,只有窃取时才需要同步
- 负载均衡:忙的线程不断产生子任务,闲的线程去偷,自动平衡
- 提高缓存命中率:窃取时从队列尾部拿,和拥有线程操作的是不同位置,减少伪共享
// ForkJoinPool 的工作窃取示意
if (localQueue.isNotEmpty()) {
task = localQueue.pop(); // 自己的队列,直接 pop
} else {
// 队列空了,去偷别人的
task = stealFromOtherQueue();
}
RecursiveTask vs RecursiveAction
ForkJoinPool 有两种任务类型:
RecursiveTask<V>:有返回值
RecursiveAction:无返回值
// 计算数组之和,有返回值
public class SumTask extends RecursiveTask<Long> {
private static final int THRESHOLD = 10000;
private final long[] array;
private final int start;
private final int end;
public SumTask(long[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected Long compute() {
// 任务足够小,直接计算
if (end - start <= THRESHOLD) {
long sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
return sum;
}
// 拆分任务
int mid = (start + end) / 2;
SumTask left = new SumTask(array, start, mid);
SumTask right = new SumTask(array, mid, end);
// fork:异步执行左半部分
left.fork();
// join:等待并获取结果,同时当前线程计算右半部分
Long rightResult = right.compute();
Long leftResult = left.join();
return leftResult + rightResult;
}
}
// 无返回值的任务:打印数组
public class PrintTask extends RecursiveAction {
private static final int THRESHOLD = 1000;
private final int[] array;
private final int start;
private final int end;
@Override
protected void compute() {
if (end - start <= THRESHOLD) {
for (int i = start; i < end; i++) {
System.out.println(array[i]);
}
return;
}
int mid = (start + end) / 2;
invokeAll(
new PrintTask(array, start, mid),
new PrintTask(array, mid, end)
);
}
}
关键点:fork() 异步执行子任务,join() 阻塞等待结果。调用 fork() 后,当前线程继续计算右半部分,这样右半部分的计算和左半部分的计算是并行的。
fork() + join() 原理
ForkJoinPool 的执行流程:
- Fork:将大任务拆成小任务,提交到队列
- Join:等待子任务完成,合并结果
- Work-Stealing:空闲线程从其他线程队列尾部偷任务
sequenceDiagram
participant Main as 主线程
participant Pool as ForkJoinPool
participant T1 as Worker 1
participant T2 as Worker 2
Main->>Pool: fork(Task A) - 拆成 A1 + A2
Pool->>T1: 执行 A1
T1->>T1: fork(A1.1)
T1->>T1: compute(A1.2) 直接执行
T1->>T1: join(A1.1) 等待结果
T2->>T2: steal(A1.1) 从 T1 队列偷任务
T2->>T2: compute(A1.1) 执行
T1->>Pool: 返回 A1 结果
Pool->>Main: 合并所有结果
Fork/Join 的适用场景
Fork/Join 适合满足以下条件的任务:
- 可分解:任务可以递归拆分成更小的子任务
- 独立性强:子任务之间没有共享数据的写操作(只读也可以)
- 计算密集型:CPU 密集的计算任务,而非大量 I/O 等待
典型应用场景:
- 归并排序、QuickSort:将数组分成多段并行排序,最后合并
- 矩阵运算:矩阵乘法、转置等可以分块并行
- 树形结构的遍历和聚合:如文件大小统计、DOM 树遍历
- 并行流(parallelStream):JDK 的并行流底层就是 ForkJoinPool
归并排序并行化
public class ParallelMergeSort extends RecursiveTask<int[]> {
private static final int THRESHOLD = 1024;
private final int[] array;
private final int left;
private final int right;
public ParallelMergeSort(int[] array) {
this(array, 0, array.length);
}
private ParallelMergeSort(int[] array, int left, int right) {
this.array = array;
this.left = left;
this.right = right;
}
@Override
protected int[] compute() {
int length = right - left;
if (length <= THRESHOLD) {
return Arrays.copyOfRange(array, left, right);
}
int mid = left + length / 2;
ParallelMergeSort leftTask = new ParallelMergeSort(array, left, mid);
ParallelMergeSort rightTask = new ParallelMergeSort(array, mid, right);
leftTask.fork();
int[] rightResult = rightTask.compute();
int[] leftResult = leftTask.join();
return merge(leftResult, rightResult);
}
private int[] merge(int[] left, int[] right) {
int[] result = new int[left.length + right.length];
int i = 0, j = 0, k = 0;
while (i < left.length && j < right.length) {
result[k++] = left[i] <= right[j] ? left[i++] : right[j++];
}
while (i < left.length) result[k++] = left[i++];
while (j < right.length) result[k++] = right[j++];
return result;
}
}
并行流(parallelStream)底层使用 ForkJoinPool
Java 8 的 parallelStream 是 ForkJoinPool 的上层封装,使用起来非常简单:
// 串行流
List<Long> result = list.stream()
.map(this::heavyComputation)
.collect(Collectors.toList());
// 并行流(自动使用 ForkJoinPool)
List<Long> parallelResult = list.parallelStream()
.map(this::heavyComputation)
.collect(Collectors.toList());
默认使用公共线程池:ForkJoinPool.commonPool(),默认大小是 Runtime.getRuntime().availableProcessors() - 1。
问题:所有并行流共享同一个池,可能产生竞争。
解决方案:为特定任务创建专用的 ForkJoinPool:
ForkJoinPool pool = new ForkJoinPool(
Runtime.getRuntime().availableProcessors() * 2
);
long result = pool.submit(() ->
list.parallelStream()
.mapToLong(this::heavyComputation)
.sum()
).join();
pool.shutdown();
:::warning
parallelStream 看起来很美好,但有几个坑:
- 任务必须可分解。如果数据源是
LinkedList,分解会很慢,因为它需要随机访问。
- 避免改变共享状态。并行流的任务应该是纯函数。
- 小心装箱。
parallelStream().map(Long::parseLong) 会大量装箱,建议用原始类型流。
- 执行顺序不确定。如果依赖顺序,用
forEachOrdered() 而不是 forEach()。
:::
ForkJoinPool 参数调优
ForkJoinPool pool = new ForkJoinPool(
parallelism, // 并行度,默认 CPU 核心数
factory, // 线程工厂
handler, // 异常处理器
asyncMode // true: 异步模式,适合事件型任务
);
parallelism:期望的并发线程数。通常设为 CPU 核心数。
asyncMode:如果任务大多是短暂的小任务(如事件处理),设为 true。如果任务较长且需要返回结果,设为 false(默认)。
// 推荐配置:并行度设为 CPU 核心数
ForkJoinPool pool = new ForkJoinPool(
Runtime.getRuntime().availableProcessors()
);
总结与延伸
Fork/Join 是分治策略在并行计算中的完美实现:
核心优势:
- 工作窃取算法,最大化 CPU 利用率
- 自动负载均衡
- 适合递归分治任务
适用场景:
- 归并排序、快速排序
- 大数组/大文件的并行处理
- 分形计算、Mandelbrot 集合
- JDK parallelStream 底层实现
注意事项:
- 任务分解要有终止条件,避免无限拆分
- 避免在任务中执行阻塞操作,否则会浪费工作线程
- 合理设置阈值,分解成本不能超过合并收益
- 线程池不要创建过多,通常一个应用一个 ForkJoinPool 即可
那么问题来了:ForkJoinPool 的工作窃取是从队列头部还是尾部拿任务?为什么这样设计?这涉及到伪共享(False Sharing)问题——从尾部窃取可以最大化缓存命中率。