001package org.junit.experimental;
002
003import java.util.concurrent.ExecutorService;
004import java.util.concurrent.Executors;
005import java.util.concurrent.TimeUnit;
006
007import org.junit.runner.Computer;
008import org.junit.runner.Runner;
009import org.junit.runners.ParentRunner;
010import org.junit.runners.model.InitializationError;
011import org.junit.runners.model.RunnerBuilder;
012import org.junit.runners.model.RunnerScheduler;
013
014public class ParallelComputer extends Computer {
015    private final boolean classes;
016
017    private final boolean methods;
018
019    public ParallelComputer(boolean classes, boolean methods) {
020        this.classes = classes;
021        this.methods = methods;
022    }
023
024    public static Computer classes() {
025        return new ParallelComputer(true, false);
026    }
027
028    public static Computer methods() {
029        return new ParallelComputer(false, true);
030    }
031
032    private static Runner parallelize(Runner runner) {
033        if (runner instanceof ParentRunner) {
034            ((ParentRunner<?>) runner).setScheduler(new RunnerScheduler() {
035                private final ExecutorService fService = Executors.newCachedThreadPool();
036
037                public void schedule(Runnable childStatement) {
038                    fService.submit(childStatement);
039                }
040
041                public void finished() {
042                    try {
043                        fService.shutdown();
044                        fService.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
045                    } catch (InterruptedException e) {
046                        e.printStackTrace(System.err);
047                    }
048                }
049            });
050        }
051        return runner;
052    }
053
054    @Override
055    public Runner getSuite(RunnerBuilder builder, java.lang.Class<?>[] classes)
056            throws InitializationError {
057        Runner suite = super.getSuite(builder, classes);
058        return this.classes ? parallelize(suite) : suite;
059    }
060
061    @Override
062    protected Runner getRunner(RunnerBuilder builder, Class<?> testClass)
063            throws Throwable {
064        Runner runner = super.getRunner(builder, testClass);
065        return methods ? parallelize(runner) : runner;
066    }
067}