
java - Is it possible to write a recursive fork/join solution using Executors.newWorkStealingPool()?

Reposted. Author: 行者123. Updated: 2023-12-01 09:01:26

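The code below is meant to show a simple use of recursive fork/join (finding the maximum). I know the Java JIT can do this faster with a simple single-threaded loop, but this is only a demonstration.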
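For reference, the single-threaded loop mentioned above might look like the following minimal sketch. The class name SequentialMax is hypothetical. The loop starts from Double.NEGATIVE_INFINITY so that arrays containing only negative values are handled correctly.

```java
// Hypothetical sequential baseline: a plain loop over the array, no threads involved.
class SequentialMax {

    /** Returns the largest element; returns NEGATIVE_INFINITY for an empty array. */
    static double max(double[] data) {
        // Start from negative infinity so any finite value (including negatives) wins
        double max = Double.NEGATIVE_INFINITY;
        for (double d : data) {
            if (d > max) {
                max = d;
            }
        }
        return max;
    }

    public static void main(String[] args) {
        double[] data = new double[1024 * 1024];
        for (int i = 0; i < data.length; i++) {
            data[i] = Math.random();
        }
        System.out.println("max = " + max(data));
    }
}
```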

I originally implemented find-max with the ForkJoin framework, and it worked well for a large double array (1024*1024 elements).

I feel I should be able to achieve the same thing without the ForkJoin framework classes, using only Executors.newWorkStealingPool() and Callables/Futures.
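One reason this is plausible concerns what newWorkStealingPool() actually returns. In OpenJDK 8 and later, Executors.newWorkStealingPool() is declared to return only an ExecutorService, but the object it returns is a ForkJoinPool. As a result, even plain Callables submitted to it run on ForkJoinWorkerThreads. The following sketch checks both facts at runtime; the class and method names are hypothetical.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.Future;

// Hypothetical helper class: inspects what newWorkStealingPool() really gives back.
class WorkStealingPoolCheck {

    /** True if Executors.newWorkStealingPool() returns a ForkJoinPool instance. */
    static boolean isForkJoinPool() {
        ExecutorService es = Executors.newWorkStealingPool();
        try {
            return es instanceof ForkJoinPool;
        } finally {
            es.shutdown();
        }
    }

    /** True if a plain Callable submitted to that pool runs on a fork/join worker thread. */
    static boolean callableRunsOnForkJoinWorker() {
        ExecutorService es = Executors.newWorkStealingPool();
        try {
            Future<Boolean> f = es.submit(
                    () -> Thread.currentThread() instanceof ForkJoinWorkerThread);
            return f.get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            es.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("ForkJoinPool? " + isForkJoinPool());
        System.out.println("Runs on ForkJoinWorkerThread? " + callableRunsOnForkJoinWorker());
    }
}
```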

Is this possible?

My attempt is below:

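```java
import java.util.Arrays;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class MaxTask implements Callable<Double> {

    private final double[] array;
    private final ExecutorService executorService;

    public MaxTask(double[] array, ExecutorService es) {
        this.array = array;
        this.executorService = es;
    }

    @Override
    public Double call() throws Exception {
        if (array.length > 2) {
            // Copy each half into a new array (copyOfRange also handles odd lengths)
            int half = array.length / 2;
            double[] a = Arrays.copyOfRange(array, 0, half);
            double[] b = Arrays.copyOfRange(array, half, array.length);
            // Submit both halves to the pool, then block until both results are available
            Future<Double> f1 = executorService.submit(new MaxTask(a, executorService));
            Future<Double> f2 = executorService.submit(new MaxTask(b, executorService));
            return Math.max(f1.get(), f2.get());
        } else if (array.length == 2) {
            return Math.max(array[0], array[1]);
        } else {
            return array[0]; // single element (an empty array is not supported)
        }
    }
}

// Driver (wrapper class name is hypothetical, added so the snippet compiles)
class MaxTaskDemo {
    public static void main(String[] args) throws Exception {
        ExecutorService es = Executors.newWorkStealingPool();

        double[] x = new double[1024 * 1024];
        for (int i = 0; i < x.length; i++) {
            x[i] = Math.random();
        }

        MaxTask mt = new MaxTask(x, es);
        System.out.println("max = " + es.submit(mt).get());
        es.shutdown();
    }
}
```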

Best answer

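It does seem possible to write a "fork/join"-style computation without the ForkJoin framework (see the use of Callable below). Using the ForkJoin framework itself does not seem to make a performance difference. The code may be a little tidier with it, but I prefer to use plain Callables.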

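I also fixed the original attempt. Its threshold appears to have been too small, which is why it was slow: it split the array all the way down to 2-element pieces, creating about a million tiny tasks. In the code below, threshold is the number of leaf chunks the array is divided into, and I think it needs to be at least as large as the number of cores.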
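To make the "threshold too small" point concrete, a hypothetical helper can count the nodes of the split tree, one task per node. In the question's version, each of those tasks also copies its half of the array. With 1024*1024 elements, splitting down to 2-element leaves gives 2^20 - 1 = 1,048,575 tasks. Stopping at 8 leaf chunks gives only 15 tasks.

```java
// Hypothetical helper: counts tasks created by recursive halving down to a leaf size.
class TaskCount {

    /** Number of tasks (tree nodes) for an array of the given length. */
    static long countTasks(int length, int leafSize) {
        if (length <= leafSize) {
            return 1; // leaf: scanned sequentially
        }
        int left = length / 2;
        int right = length - left;
        return 1 + countTasks(left, leafSize) + countTasks(right, leafSize);
    }

    public static void main(String[] args) {
        int n = 1024 * 1024;
        int cores = Runtime.getRuntime().availableProcessors();
        // Question's version: split down to 2-element leaves
        System.out.println("leafSize = 2: " + countTasks(n, 2) + " tasks");
        // Answer's version: about one leaf chunk per core
        System.out.println("leafSize = n/" + cores + ": " + countTasks(n, n / cores) + " tasks");
    }
}
```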

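I'm not sure whether using a ForkJoinPool is faster; I'd need to gather more statistics. I don't think it is, since the computation has no long-running blocking operations.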

```java
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;

public class Main {

    /** Sequential max over data[startIndex..endIndex] (both ends inclusive). */
    static double maxOf(double[] data, int startIndex, int endIndex) {
        // Double.MIN_VALUE is the smallest *positive* double, so start from -infinity instead
        double max = Double.NEGATIVE_INFINITY;
        for (int i = startIndex; i <= endIndex; i++) {
            if (data[i] > max) {
                max = data[i];
            }
        }
        return max;
    }

    /** Stop splitting once a chunk has at most data.length / threshold elements. */
    static int leafSize(double[] data, int threshold) {
        return Math.max(1, data.length / threshold);
    }

    /** Fork/join framework version (RecursiveTask). */
    static class FindMaxTask extends RecursiveTask<Double> {

        private final int threshold; // roughly the number of leaf chunks
        private final double[] data;
        private final int startIndex;
        private final int endIndex; // inclusive

        public FindMaxTask(double[] data, int startIndex, int endIndex, int threshold) {
            this.data = data;
            this.startIndex = startIndex;
            this.endIndex = endIndex;
            this.threshold = threshold;
        }

        @Override
        protected Double compute() {
            int diff = endIndex - startIndex + 1;
            // Use ">" rather than "!=": halving may never land exactly on data.length / threshold
            // (e.g. when threshold is not a power of two), and then the recursion never terminates
            if (diff > leafSize(data, threshold)) {
                int mid = startIndex + diff / 2;
                FindMaxTask f1 = new FindMaxTask(data, startIndex, mid - 1, threshold);
                f1.fork(); // left half runs asynchronously
                FindMaxTask f2 = new FindMaxTask(data, mid, endIndex, threshold);
                // Compute the right half *before* joining: Math.max(f1.join(), f2.compute())
                // evaluates left to right, so it would wait for the left half first
                double right = f2.compute();
                return Math.max(f1.join(), right);
            }
            return maxOf(data, startIndex, endIndex);
        }
    }

    /** Plain Callable/Future version; works with any ExecutorService. */
    static class FindMax implements Callable<Double> {

        private final double[] data;
        private final int startIndex;
        private final int endIndex; // inclusive
        private final int threshold;
        private final ExecutorService executorService;

        public FindMax(double[] data, int startIndex, int endIndex, int threshold, ExecutorService executorService) {
            this.data = data;
            this.startIndex = startIndex;
            this.endIndex = endIndex;
            this.threshold = threshold;
            this.executorService = executorService;
        }

        @Override
        public Double call() throws Exception {
            int diff = endIndex - startIndex + 1;
            if (diff > leafSize(data, threshold)) {
                int mid = startIndex + diff / 2;
                // Submit only the left half and compute the right half in this thread.
                // A task blocked in get() holds a pool thread; in a fixed-size pool, if every
                // thread ends up waiting, the pool deadlocks. Computing one half in place
                // roughly halves the number of waiting threads.
                Future<Double> f1 = executorService.submit(
                        new FindMax(data, startIndex, mid - 1, threshold, executorService));
                double right = new FindMax(data, mid, endIndex, threshold, executorService).call();
                return Math.max(f1.get(), right);
            }
            return maxOf(data, startIndex, endIndex);
        }
    }

    public static void main(String[] args) throws InterruptedException, ExecutionException {
        // 64M doubles = 512 MB; this may require a larger heap (e.g. -Xmx1g)
        double[] data = new double[1024 * 1024 * 64];
        for (int i = 0; i < data.length; i++) {
            data[i] = Math.random();
        }

        int p = Runtime.getRuntime().availableProcessors();
        int threshold = p; // number of leaf chunks
        int threads = p;
        Instant start;
        Instant end;

        ExecutorService es = Executors.newFixedThreadPool(threads);
        System.out.println("1. started..");
        start = Instant.now();
        System.out.println("max = " + es.submit(new FindMax(data, 0, data.length - 1, threshold, es)).get());
        end = Instant.now();
        System.out.println("Callable (recursive), with fixed pool, Find Max took ms = " + Duration.between(start, end).toMillis());
        es.shutdown(); // fixed-pool threads are non-daemon; without shutdown() the JVM never exits

        es = new ForkJoinPool();
        System.out.println("2. started..");
        start = Instant.now();
        System.out.println("max = " + es.submit(new FindMax(data, 0, data.length - 1, threshold, es)).get());
        end = Instant.now();
        System.out.println("Callable (recursive), with fork join pool, Find Max took ms = " + Duration.between(start, end).toMillis());
        es.shutdown();

        ForkJoinPool fj = new ForkJoinPool(threads);
        System.out.println("3. started..");
        start = Instant.now();
        System.out.println("max = " + fj.invoke(new FindMaxTask(data, 0, data.length - 1, threshold)));
        end = Instant.now();
        System.out.println("RecursiveTask (fork/join framework), with fork join pool, Find Max took ms = " + Duration.between(start, end).toMillis());
        fj.shutdown();
    }
}
```
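The benchmark above uses new ForkJoinPool() rather than calling Executors.newWorkStealingPool() itself. The following minimal sketch uses newWorkStealingPool() directly. It passes half-open index ranges instead of copying sub-arrays, and forks only one half per step. The names WorkStealingMax, RangeMaxTask and parallelMax, and the leaf-size rule (about one chunk per available processor), are assumptions rather than part of the original answer.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical sketch: recursive max with plain Callables on Executors.newWorkStealingPool().
class WorkStealingMax {

    /** Finds the max of data[from, to) (half-open range) by recursive halving. */
    static final class RangeMaxTask implements Callable<Double> {
        private final double[] data;
        private final int from;
        private final int to;
        private final int leafSize;
        private final ExecutorService pool;

        RangeMaxTask(double[] data, int from, int to, int leafSize, ExecutorService pool) {
            this.data = data;
            this.from = from;
            this.to = to;
            this.leafSize = leafSize;
            this.pool = pool;
        }

        @Override
        public Double call() throws Exception {
            if (to - from <= leafSize) {
                // Leaf: sequential scan over the shared array, no copies
                double max = Double.NEGATIVE_INFINITY;
                for (int i = from; i < to; i++) {
                    max = Math.max(max, data[i]);
                }
                return max;
            }
            int mid = (from + to) >>> 1;
            // Fork the left half to the pool, compute the right half here, then join
            Future<Double> left = pool.submit(new RangeMaxTask(data, from, mid, leafSize, pool));
            double right = new RangeMaxTask(data, mid, to, leafSize, pool).call();
            return Math.max(left.get(), right);
        }
    }

    /** Parallel max; returns NEGATIVE_INFINITY for an empty array. */
    static double parallelMax(double[] data) {
        ExecutorService pool = Executors.newWorkStealingPool();
        try {
            int parallelism = Runtime.getRuntime().availableProcessors();
            int leafSize = Math.max(1, data.length / parallelism);
            return pool.submit(new RangeMaxTask(data, 0, data.length, leafSize, pool)).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        double[] data = new double[1024 * 1024];
        for (int i = 0; i < data.length; i++) {
            data[i] = Math.random();
        }
        System.out.println("max = " + parallelMax(data));
    }
}
```

Computing one half inline follows the usual fork/compute/join pattern of a RecursiveTask. Only half of the subtasks go through the pool's queue, and the submitting thread stays busy instead of idling in get().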

For the question "java - Is it possible to write a recursive fork/join solution using Executors.newWorkStealingPool()?", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/41634320/
