001package org.junit.runners.parameterized;
002
003import java.lang.annotation.Annotation;
004import java.lang.reflect.Field;
005import java.util.List;
006
007import org.junit.internal.runners.statements.RunAfters;
008import org.junit.internal.runners.statements.RunBefores;
009import org.junit.runner.RunWith;
010import org.junit.runner.notification.RunNotifier;
011import org.junit.runners.BlockJUnit4ClassRunner;
012import org.junit.runners.Parameterized;
013import org.junit.runners.Parameterized.Parameter;
014import org.junit.runners.model.FrameworkField;
015import org.junit.runners.model.FrameworkMethod;
016import org.junit.runners.model.InitializationError;
017import org.junit.runners.model.Statement;
018
019/**
020 * A {@link BlockJUnit4ClassRunner} with parameters support. Parameters can be
021 * injected via constructor or into annotated fields.
022 */
023public class BlockJUnit4ClassRunnerWithParameters extends
024        BlockJUnit4ClassRunner {
025    private enum InjectionType {
026        CONSTRUCTOR, FIELD
027    }
028
029    private final Object[] parameters;
030
031    private final String name;
032
033    public BlockJUnit4ClassRunnerWithParameters(TestWithParameters test)
034            throws InitializationError {
035        super(test.getTestClass());
036        parameters = test.getParameters().toArray(
037                new Object[test.getParameters().size()]);
038        name = test.getName();
039    }
040
041    @Override
042    public Object createTest() throws Exception {
043        InjectionType injectionType = getInjectionType();
044        switch (injectionType) {
045            case CONSTRUCTOR:
046                return createTestUsingConstructorInjection();
047            case FIELD:
048                return createTestUsingFieldInjection();
049            default:
050                throw new IllegalStateException("The injection type "
051                        + injectionType + " is not supported.");
052        }
053    }
054
055    private Object createTestUsingConstructorInjection() throws Exception {
056        return getTestClass().getOnlyConstructor().newInstance(parameters);
057    }
058
059    private Object createTestUsingFieldInjection() throws Exception {
060        List<FrameworkField> annotatedFieldsByParameter = getAnnotatedFieldsByParameter();
061        if (annotatedFieldsByParameter.size() != parameters.length) {
062            throw new Exception(
063                    "Wrong number of parameters and @Parameter fields."
064                            + " @Parameter fields counted: "
065                            + annotatedFieldsByParameter.size()
066                            + ", available parameters: " + parameters.length
067                            + ".");
068        }
069        Object testClassInstance = getTestClass().getJavaClass().newInstance();
070        for (FrameworkField each : annotatedFieldsByParameter) {
071            Field field = each.getField();
072            Parameter annotation = field.getAnnotation(Parameter.class);
073            int index = annotation.value();
074            try {
075                field.set(testClassInstance, parameters[index]);
076            } catch (IllegalAccessException e) {
077                IllegalAccessException wrappedException = new IllegalAccessException(
078                        "Cannot set parameter '" + field.getName()
079                                + "'. Ensure that the field '" + field.getName()
080                                + "' is public.");
081                wrappedException.initCause(e);
082                throw wrappedException;
083            } catch (IllegalArgumentException iare) {
084                throw new Exception(getTestClass().getName()
085                        + ": Trying to set " + field.getName()
086                        + " with the value " + parameters[index]
087                        + " that is not the right type ("
088                        + parameters[index].getClass().getSimpleName()
089                        + " instead of " + field.getType().getSimpleName()
090                        + ").", iare);
091            }
092        }
093        return testClassInstance;
094    }
095
096    @Override
097    protected String getName() {
098        return name;
099    }
100
101    @Override
102    protected String testName(FrameworkMethod method) {
103        return method.getName() + getName();
104    }
105
106    @Override
107    protected void validateConstructor(List<Throwable> errors) {
108        validateOnlyOneConstructor(errors);
109        if (getInjectionType() != InjectionType.CONSTRUCTOR) {
110            validateZeroArgConstructor(errors);
111        }
112    }
113
114    @Override
115    protected void validateFields(List<Throwable> errors) {
116        super.validateFields(errors);
117        if (getInjectionType() == InjectionType.FIELD) {
118            List<FrameworkField> annotatedFieldsByParameter = getAnnotatedFieldsByParameter();
119            int[] usedIndices = new int[annotatedFieldsByParameter.size()];
120            for (FrameworkField each : annotatedFieldsByParameter) {
121                int index = each.getField().getAnnotation(Parameter.class)
122                        .value();
123                if (index < 0 || index > annotatedFieldsByParameter.size() - 1) {
124                    errors.add(new Exception("Invalid @Parameter value: "
125                            + index + ". @Parameter fields counted: "
126                            + annotatedFieldsByParameter.size()
127                            + ". Please use an index between 0 and "
128                            + (annotatedFieldsByParameter.size() - 1) + "."));
129                } else {
130                    usedIndices[index]++;
131                }
132            }
133            for (int index = 0; index < usedIndices.length; index++) {
134                int numberOfUse = usedIndices[index];
135                if (numberOfUse == 0) {
136                    errors.add(new Exception("@Parameter(" + index
137                            + ") is never used."));
138                } else if (numberOfUse > 1) {
139                    errors.add(new Exception("@Parameter(" + index
140                            + ") is used more than once (" + numberOfUse + ")."));
141                }
142            }
143        }
144    }
145
146    @Override
147    protected Statement classBlock(RunNotifier notifier) {
148        Statement statement = childrenInvoker(notifier);
149        statement = withBeforeParams(statement);
150        statement = withAfterParams(statement);
151        return statement;
152    }
153
154    private Statement withBeforeParams(Statement statement) {
155        List<FrameworkMethod> befores = getTestClass()
156                .getAnnotatedMethods(Parameterized.BeforeParam.class);
157        return befores.isEmpty() ? statement : new RunBeforeParams(statement, befores);
158    }
159
160    private class RunBeforeParams extends RunBefores {
161        RunBeforeParams(Statement next, List<FrameworkMethod> befores) {
162            super(next, befores, null);
163        }
164
165        @Override
166        protected void invokeMethod(FrameworkMethod method) throws Throwable {
167            int paramCount = method.getMethod().getParameterTypes().length;
168            method.invokeExplosively(null, paramCount == 0 ? (Object[]) null : parameters);
169        }
170    }
171
172    private Statement withAfterParams(Statement statement) {
173        List<FrameworkMethod> afters = getTestClass()
174                .getAnnotatedMethods(Parameterized.AfterParam.class);
175        return afters.isEmpty() ? statement : new RunAfterParams(statement, afters);
176    }
177
178    private class RunAfterParams extends RunAfters {
179        RunAfterParams(Statement next, List<FrameworkMethod> afters) {
180            super(next, afters, null);
181        }
182
183        @Override
184        protected void invokeMethod(FrameworkMethod method) throws Throwable {
185            int paramCount = method.getMethod().getParameterTypes().length;
186            method.invokeExplosively(null, paramCount == 0 ? (Object[]) null : parameters);
187        }
188    }
189
190    @Override
191    protected Annotation[] getRunnerAnnotations() {
192        Annotation[] allAnnotations = super.getRunnerAnnotations();
193        Annotation[] annotationsWithoutRunWith = new Annotation[allAnnotations.length - 1];
194        int i = 0;
195        for (Annotation annotation: allAnnotations) {
196            if (!annotation.annotationType().equals(RunWith.class)) {
197                annotationsWithoutRunWith[i] = annotation;
198                ++i;
199            }
200        }
201        return annotationsWithoutRunWith;
202    }
203
204    private List<FrameworkField> getAnnotatedFieldsByParameter() {
205        return getTestClass().getAnnotatedFields(Parameter.class);
206    }
207
208    private InjectionType getInjectionType() {
209        if (fieldsAreAnnotated()) {
210            return InjectionType.FIELD;
211        } else {
212            return InjectionType.CONSTRUCTOR;
213        }
214    }
215
216    private boolean fieldsAreAnnotated() {
217        return !getAnnotatedFieldsByParameter().isEmpty();
218    }
219}