Fork-Join

3/23/2022 线程池线程

# 简介

Fork/Join框架是Java并发工具包中的一种可以将一个大任务拆分为很多小任务来异步执行的工具,自JDK1.7引入。

Fork/Join框架主要包含三个模块:

  • 任务对象: ForkJoinTask (包括RecursiveTask、RecursiveAction 和 CountedCompleter)

  • 执行Fork/Join任务的线程: ForkJoinWorkerThread

  • 线程池: ForkJoinPool

这三者的关系是: ForkJoinPool可以通过池中的ForkJoinWorkerThread来处理ForkJoinTask任务。

ForkJoinPool 只接收 ForkJoinTask 任务(在实际使用中,也可以接收 Runnable/Callable 任务,但在真正运行时,也会把这些任务封装成 ForkJoinTask 类型的任务).

RecursiveTask 是 ForkJoinTask 的子类,是一个可以递归执行的 ForkJoinTask;RecursiveAction 是一个无返回值的 RecursiveTask;CountedCompleter 在任务完成执行后会触发执行一个自定义的钩子函数。

在实际运用中,我们一般都会继承 RecursiveTask 、RecursiveAction 或 CountedCompleter 来实现我们的业务需求,而不会直接继承 ForkJoinTask 类。

// from 《A Java Fork/Join Framework》Dong Lea
Result solve(Problem problem) {
	if (problem is small)
 		directly solve problem
 	else {
 		split problem into independent parts
 		fork new subtasks to solve each part
 		join all subtasks
 		compose result from subresults
	}
}

# 工作流程和核心思想

# 分支思想

img

# 工作窃取思想

  1. work-stealing(工作窃取)算法: 线程池内的所有工作线程都尝试找到并执行已经提交的任务,或者是被其他活动任务创建的子任务(如果不存在就阻塞等待)。这种特性使得 ForkJoinPool 在运行多个可以产生子任务的任务,或者是提交的许多小任务时效率更高。尤其是构建异步模型的 ForkJoinPool 时,对不需要合并(join)的事件类型任务也非常适用。

  2. 在 ForkJoinPool 中,线程池中每个工作线程(ForkJoinWorkerThread)都对应一个任务队列(WorkQueue),工作线程优先处理来自自身队列的任务(LIFO或FIFO顺序,参数 mode 决定),然后以FIFO的顺序随机窃取其他队列中的任务。

  3. 具体思路如下:

    • 每个线程都有自己的一个WorkQueue,该工作队列是一个双端队列。

    • 队列支持三个功能push、pop、poll

    • push/pop只能被队列的所有者线程调用,而poll可以被其他线程调用。

    • 划分的子任务调用fork时,都会被push到自己的队列中。

    • 默认情况下,工作线程从自己的双端队列获出任务并执行。

    • 当自己的队列为空时,线程随机从另一个线程的队列末尾调用poll方法窃取任务。

img

# 执行流程

  1. ForkJoinPool 中的任务执行分两种:

    • 直接通过 FJP 提交的外部任务(external/submissions task),存放在 workQueues 的偶数槽位;

    • 通过内部 fork 分割的子任务(Worker task),存放在 workQueues 的奇数槽位。

  2. Fork/Join 框架的执行流程如下:

img

# Fork/Join类关系

# ForkJoinPool继承关系

img

内部类介绍:

  • ForkJoinWorkerThreadFactory: 内部线程工厂接口,用于创建工作线程ForkJoinWorkerThread

  • DefaultForkJoinWorkerThreadFactory: ForkJoinWorkerThreadFactory 的默认实现类

  • InnocuousForkJoinWorkerThreadFactory: 实现了 ForkJoinWorkerThreadFactory,无许可线程工厂,当系统变量中有系统安全管理相关属性时,默认使用这个工厂创建工作线程。

  • EmptyTask: 内部占位类,用于替换队列中 join 的任务。

  • ManagedBlocker: 为 ForkJoinPool 中的任务提供扩展管理并行数的接口,一般用在可能会阻塞的任务(如在 Phaser 中用于等待 phase 到下一个generation)。

  • WorkQueue: ForkJoinPool 的核心数据结构,本质上是work-stealing 模式的双端任务队列,内部存放 ForkJoinTask 对象任务,使用 @Contented 注解修饰防止伪共享。

    • 工作线程在运行中产生新的任务(通常是因为调用了 fork())时,此时可以把 WorkQueue 的数据结构视为一个栈,新的任务会放入栈顶(top 位);工作线程在处理自己工作队列的任务时,按照 LIFO 的顺序。

    • 工作线程在处理自己的工作队列同时,会尝试窃取一个任务(可能是来自于刚刚提交到 pool 的任务,或是来自于其他工作线程的队列任务),此时可以把 WorkQueue 的数据结构视为一个 FIFO 的队列,窃取的任务位于其他线程的工作队列的队首(base位)。

  • 伪共享状态: 缓存系统中是以缓存行(cache line)为单位存储的。缓存行是2的整数幂个连续字节,一般为32-256个字节。最常见的缓存行大小是64个字节。当多线程修改互相独立的变量时,如果这些变量共享同一个缓存行,就会无意中影响彼此的性能,这就是伪共享。

# ForkJoinTask继承关系

img

  1. ForkJoinTask 实现了 Future 接口,说明它也是一个可取消的异步运算任务,实际上ForkJoinTask 是 Future 的轻量级实现,主要用在纯粹是计算的函数式任务或者操作完全独立的对象计算任务。

  2. fork 是主运行方法,用于异步执行;而 join 方法在任务结果计算完毕之后才会运行,用来合并或返回计算结果。

  3. 其内部类都比较简单,ExceptionNode 是用于存储任务执行期间的异常信息的单向链表;其余四个类是为 Runnable/Callable 任务提供的适配器类,用于把 Runnable/Callable 转化为 ForkJoinTask 类型的任务(因为 ForkJoinPool 只可以运行 ForkJoinTask 类型的任务)。

# 源码分析

# 执行流程 - 外部任务(external/submissions task)提交

# 执行流程: 子任务(Worker task)提交

# 执行流程: 任务执行

# 获取任务结果 - ForkJoinTask.join() / ForkJoinTask.invoke()

# 常见问题

# 有哪些JDK源码中使用了Fork/Join思想?

  1. 我们常用的数组工具类 Arrays 在JDK 8之后新增的并行排序方法(parallelSort)就运用了 ForkJoinPool 的特性

  2. 还有 ConcurrentHashMap 在JDK 8之后添加的函数式方法(如forEach等)也有运用。

  3. stream流的map等方法也使用了fork/join思想

# 使用Executors工具类创建ForkJoinPool

// parallelism定义并行级别
public static ExecutorService newWorkStealingPool(int parallelism);
// 默认并行级别为JVM可用的处理器个数
// Runtime.getRuntime().availableProcessors()
public static ExecutorService newWorkStealingPool();

# Fork/Join示例

# 采用Fork/Join来异步计算1+2+3+…+10000的结果

public class Test {
	static final class SumTask extends RecursiveTask<Integer> {
		private static final long serialVersionUID = 1L;
		
		final int start; //开始计算的数
		final int end; //最后计算的数
		
		SumTask(int start, int end) {
			this.start = start;
			this.end = end;
		}

		@Override
		protected Integer compute() {
			//如果计算量小于1000,那么分配一个线程执行if中的代码块,并返回执行结果
			if(end - start < 1000) {
				System.out.println(Thread.currentThread().getName() + " 开始执行: " + start + "-" + end);
				int sum = 0;
				for(int i = start; i <= end; i++)
					sum += i;
				return sum;
			}
			//如果计算量大于1000,那么拆分为两个任务
			SumTask task1 = new SumTask(start, (start + end) / 2);
			SumTask task2 = new SumTask((start + end) / 2 + 1, end);
			//执行任务
			task1.fork();
			task2.fork();
			//获取任务执行的结果
			return task1.join() + task2.join();
		}
	}
	
	public static void main(String[] args) throws InterruptedException, ExecutionException {
		ForkJoinPool pool = new ForkJoinPool();
		ForkJoinTask<Integer> task = new SumTask(1, 10000);
		pool.submit(task);
		System.out.println(task.get());
	}
}

结果:

ForkJoinPool-1-worker-1 开始执行: 1-625\
ForkJoinPool-1-worker-7 开始执行: 6251-6875\
ForkJoinPool-1-worker-6 开始执行: 5626-6250\
ForkJoinPool-1-worker-10 开始执行: 3751-4375\
ForkJoinPool-1-worker-13 开始执行: 2501-3125\
ForkJoinPool-1-worker-8 开始执行: 626-1250\
ForkJoinPool-1-worker-11 开始执行: 5001-5625\
ForkJoinPool-1-worker-3 开始执行: 7501-8125\
ForkJoinPool-1-worker-14 开始执行: 1251-1875\
ForkJoinPool-1-worker-4 开始执行: 9376-10000\
ForkJoinPool-1-worker-8 开始执行: 8126-8750\
ForkJoinPool-1-worker-0 开始执行: 1876-2500\
ForkJoinPool-1-worker-12 开始执行: 4376-5000\
ForkJoinPool-1-worker-5 开始执行: 8751-9375\
ForkJoinPool-1-worker-7 开始执行: 6876-7500\
ForkJoinPool-1-worker-1 开始执行: 3126-3750\
50005000

# 实现斐波那契数列

public static void main(String[] args) {
    ForkJoinPool forkJoinPool = new ForkJoinPool(4); // 最大并发数4
    Fibonacci fibonacci = new Fibonacci(20);
    long startTime = System.currentTimeMillis();
    Integer result = forkJoinPool.invoke(fibonacci);
    long endTime = System.currentTimeMillis();
    System.out.println("Fork/join sum: " + result + " in " + (endTime - startTime) + " ms.");
}
//以下为官方API文档示例
static  class Fibonacci extends RecursiveTask<Integer> {
    final int n;
    Fibonacci(int n) {
        this.n = n;
    }
    @Override
    protected Integer compute() {
        if (n <= 1) {
            return n;
        }
        Fibonacci f1 = new Fibonacci(n - 1);
        f1.fork(); 
        Fibonacci f2 = new Fibonacci(n - 2);
        return f2.compute() + f1.join(); 
    }
}

# 参考

JUC线程池: Fork/Join框架详解 | Java 全栈知识体系 (opens new window)

Java Fork Join 框架 | JAVACORE (opens new window)

Last Updated: 3/29/2022, 9:08:15 AM