001package org.junit.runners.model;
002
003import static java.lang.reflect.Modifier.isStatic;
004import static org.junit.internal.MethodSorter.NAME_ASCENDING;
005
006import java.lang.annotation.Annotation;
007import java.lang.reflect.Constructor;
008import java.lang.reflect.Field;
009import java.lang.reflect.Method;
010import java.lang.reflect.Modifier;
011import java.util.ArrayList;
012import java.util.Arrays;
013import java.util.Collections;
014import java.util.Comparator;
015import java.util.LinkedHashMap;
016import java.util.LinkedHashSet;
017import java.util.List;
018import java.util.Map;
019import java.util.Set;
020
021import org.junit.Assert;
022import org.junit.Before;
023import org.junit.BeforeClass;
024import org.junit.internal.MethodSorter;
025
026/**
027 * Wraps a class to be run, providing method validation and annotation searching
028 *
029 * @since 4.5
030 */
031public class TestClass implements Annotatable {
032    private static final FieldComparator FIELD_COMPARATOR = new FieldComparator();
033    private static final MethodComparator METHOD_COMPARATOR = new MethodComparator();
034
035    private final Class<?> clazz;
036    private final Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations;
037    private final Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations;
038
039    /**
040     * Creates a {@code TestClass} wrapping {@code clazz}. Each time this
041     * constructor executes, the class is scanned for annotations, which can be
042     * an expensive process (we hope in future JDK's it will not be.) Therefore,
043     * try to share instances of {@code TestClass} where possible.
044     */
045    public TestClass(Class<?> clazz) {
046        this.clazz = clazz;
047        if (clazz != null && clazz.getConstructors().length > 1) {
048            throw new IllegalArgumentException(
049                    "Test class can only have one constructor");
050        }
051
052        Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations =
053                new LinkedHashMap<Class<? extends Annotation>, List<FrameworkMethod>>();
054        Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations =
055                new LinkedHashMap<Class<? extends Annotation>, List<FrameworkField>>();
056
057        scanAnnotatedMembers(methodsForAnnotations, fieldsForAnnotations);
058
059        this.methodsForAnnotations = makeDeeplyUnmodifiable(methodsForAnnotations);
060        this.fieldsForAnnotations = makeDeeplyUnmodifiable(fieldsForAnnotations);
061    }
062
063    protected void scanAnnotatedMembers(Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations, Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations) {
064        for (Class<?> eachClass : getSuperClasses(clazz)) {
065            for (Method eachMethod : MethodSorter.getDeclaredMethods(eachClass)) {
066                addToAnnotationLists(new FrameworkMethod(eachMethod), methodsForAnnotations);
067            }
068            // ensuring fields are sorted to make sure that entries are inserted
069            // and read from fieldForAnnotations in a deterministic order
070            for (Field eachField : getSortedDeclaredFields(eachClass)) {
071                addToAnnotationLists(new FrameworkField(eachField), fieldsForAnnotations);
072            }
073        }
074    }
075
076    private static Field[] getSortedDeclaredFields(Class<?> clazz) {
077        Field[] declaredFields = clazz.getDeclaredFields();
078        Arrays.sort(declaredFields, FIELD_COMPARATOR);
079        return declaredFields;
080    }
081
082    protected static <T extends FrameworkMember<T>> void addToAnnotationLists(T member,
083            Map<Class<? extends Annotation>, List<T>> map) {
084        for (Annotation each : member.getAnnotations()) {
085            Class<? extends Annotation> type = each.annotationType();
086            List<T> members = getAnnotatedMembers(map, type, true);
087            T memberToAdd = member.handlePossibleBridgeMethod(members);
088            if (memberToAdd == null) {
089                return;
090            }
091            if (runsTopToBottom(type)) {
092                members.add(0, memberToAdd);
093            } else {
094                members.add(memberToAdd);
095            }
096        }
097    }
098
099    private static <T extends FrameworkMember<T>> Map<Class<? extends Annotation>, List<T>>
100            makeDeeplyUnmodifiable(Map<Class<? extends Annotation>, List<T>> source) {
101        Map<Class<? extends Annotation>, List<T>> copy =
102                new LinkedHashMap<Class<? extends Annotation>, List<T>>();
103        for (Map.Entry<Class<? extends Annotation>, List<T>> entry : source.entrySet()) {
104            copy.put(entry.getKey(), Collections.unmodifiableList(entry.getValue()));
105        }
106        return Collections.unmodifiableMap(copy);
107    }
108
109    /**
110     * Returns, efficiently, all the non-overridden methods in this class and
111     * its superclasses that are annotated}.
112     * 
113     * @since 4.12
114     */
115    public List<FrameworkMethod> getAnnotatedMethods() {
116        List<FrameworkMethod> methods = collectValues(methodsForAnnotations);
117        Collections.sort(methods, METHOD_COMPARATOR);
118        return methods;
119    }
120
121    /**
122     * Returns, efficiently, all the non-overridden methods in this class and
123     * its superclasses that are annotated with {@code annotationClass}.
124     */
125    public List<FrameworkMethod> getAnnotatedMethods(
126            Class<? extends Annotation> annotationClass) {
127        return Collections.unmodifiableList(getAnnotatedMembers(methodsForAnnotations, annotationClass, false));
128    }
129
130    /**
131     * Returns, efficiently, all the non-overridden fields in this class and its
132     * superclasses that are annotated.
133     * 
134     * @since 4.12
135     */
136    public List<FrameworkField> getAnnotatedFields() {
137        return collectValues(fieldsForAnnotations);
138    }
139
140    /**
141     * Returns, efficiently, all the non-overridden fields in this class and its
142     * superclasses that are annotated with {@code annotationClass}.
143     */
144    public List<FrameworkField> getAnnotatedFields(
145            Class<? extends Annotation> annotationClass) {
146        return Collections.unmodifiableList(getAnnotatedMembers(fieldsForAnnotations, annotationClass, false));
147    }
148
149    private <T> List<T> collectValues(Map<?, List<T>> map) {
150        Set<T> values = new LinkedHashSet<T>();
151        for (List<T> additionalValues : map.values()) {
152            values.addAll(additionalValues);
153        }
154        return new ArrayList<T>(values);
155    }
156
157    private static <T> List<T> getAnnotatedMembers(Map<Class<? extends Annotation>, List<T>> map,
158            Class<? extends Annotation> type, boolean fillIfAbsent) {
159        if (!map.containsKey(type) && fillIfAbsent) {
160            map.put(type, new ArrayList<T>());
161        }
162        List<T> members = map.get(type);
163        return members == null ? Collections.<T>emptyList() : members;
164    }
165
166    private static boolean runsTopToBottom(Class<? extends Annotation> annotation) {
167        return annotation.equals(Before.class)
168                || annotation.equals(BeforeClass.class);
169    }
170
171    private static List<Class<?>> getSuperClasses(Class<?> testClass) {
172        List<Class<?>> results = new ArrayList<Class<?>>();
173        Class<?> current = testClass;
174        while (current != null) {
175            results.add(current);
176            current = current.getSuperclass();
177        }
178        return results;
179    }
180
181    /**
182     * Returns the underlying Java class.
183     */
184    public Class<?> getJavaClass() {
185        return clazz;
186    }
187
188    /**
189     * Returns the class's name.
190     */
191    public String getName() {
192        if (clazz == null) {
193            return "null";
194        }
195        return clazz.getName();
196    }
197
198    /**
199     * Returns the only public constructor in the class, or throws an {@code
200     * AssertionError} if there are more or less than one.
201     */
202
203    public Constructor<?> getOnlyConstructor() {
204        Constructor<?>[] constructors = clazz.getConstructors();
205        Assert.assertEquals(1, constructors.length);
206        return constructors[0];
207    }
208
209    /**
210     * Returns the annotations on this class
211     */
212    public Annotation[] getAnnotations() {
213        if (clazz == null) {
214            return new Annotation[0];
215        }
216        return clazz.getAnnotations();
217    }
218
219    public <T extends Annotation> T getAnnotation(Class<T> annotationType) {
220        if (clazz == null) {
221            return null;
222        }
223        return clazz.getAnnotation(annotationType);
224    }
225
226    public <T> List<T> getAnnotatedFieldValues(Object test,
227            Class<? extends Annotation> annotationClass, Class<T> valueClass) {
228        final List<T> results = new ArrayList<T>();
229        collectAnnotatedFieldValues(test, annotationClass, valueClass,
230                new MemberValueConsumer<T>() {
231                    public void accept(FrameworkMember<?> member, T value) {
232                        results.add(value);
233                    }
234                });
235        return results;
236    }
237
238    /**
239     * Finds the fields annotated with the specified annotation and having the specified type,
240     * retrieves the values and passes those to the specified consumer.
241     *
242     * @since 4.13
243     */
244    public <T> void collectAnnotatedFieldValues(Object test,
245            Class<? extends Annotation> annotationClass, Class<T> valueClass,
246            MemberValueConsumer<T> consumer) {
247        for (FrameworkField each : getAnnotatedFields(annotationClass)) {
248            try {
249                Object fieldValue = each.get(test);
250                if (valueClass.isInstance(fieldValue)) {
251                    consumer.accept(each, valueClass.cast(fieldValue));
252                }
253            } catch (IllegalAccessException e) {
254                throw new RuntimeException(
255                        "How did getFields return a field we couldn't access?", e);
256            }
257        }
258    }
259
260    public <T> List<T> getAnnotatedMethodValues(Object test,
261            Class<? extends Annotation> annotationClass, Class<T> valueClass) {
262        final List<T> results = new ArrayList<T>();
263        collectAnnotatedMethodValues(test, annotationClass, valueClass,
264                new MemberValueConsumer<T>() {
265                    public void accept(FrameworkMember<?> member, T value) {
266                        results.add(value);
267                    }
268                });
269        return results;
270    }
271
272    /**
273     * Finds the methods annotated with the specified annotation and returning the specified type,
274     * invokes it and pass the return value to the specified consumer.
275     *
276     * @since 4.13
277     */
278    public <T> void collectAnnotatedMethodValues(Object test,
279            Class<? extends Annotation> annotationClass, Class<T> valueClass,
280            MemberValueConsumer<T> consumer) {
281        for (FrameworkMethod each : getAnnotatedMethods(annotationClass)) {
282            try {
283                /*
284                 * A method annotated with @Rule may return a @TestRule or a @MethodRule,
285                 * we cannot call the method to check whether the return type matches our
286                 * expectation i.e. subclass of valueClass. If we do that then the method 
287                 * will be invoked twice and we do not want to do that. So we first check
288                 * whether return type matches our expectation and only then call the method
289                 * to fetch the MethodRule
290                 */
291                if (valueClass.isAssignableFrom(each.getReturnType())) {
292                    Object fieldValue = each.invokeExplosively(test);
293                    consumer.accept(each, valueClass.cast(fieldValue));
294                }
295            } catch (Throwable e) {
296                throw new RuntimeException(
297                        "Exception in " + each.getName(), e);
298            }
299        }
300    }
301
302    public boolean isPublic() {
303        return Modifier.isPublic(clazz.getModifiers());
304    }
305
306    public boolean isANonStaticInnerClass() {
307        return clazz.isMemberClass() && !isStatic(clazz.getModifiers());
308    }
309
310    @Override
311    public int hashCode() {
312        return (clazz == null) ? 0 : clazz.hashCode();
313    }
314
315    @Override
316    public boolean equals(Object obj) {
317        if (this == obj) {
318            return true;
319        }
320        if (obj == null) {
321            return false;
322        }
323        if (getClass() != obj.getClass()) {
324            return false;
325        }
326        TestClass other = (TestClass) obj;
327        return clazz == other.clazz;
328    }
329
330    /**
331     * Compares two fields by its name.
332     */
333    private static class FieldComparator implements Comparator<Field> {
334        public int compare(Field left, Field right) {
335            return left.getName().compareTo(right.getName());
336        }
337    }
338
339    /**
340     * Compares two methods by its name.
341     */
342    private static class MethodComparator implements
343            Comparator<FrameworkMethod> {
344        public int compare(FrameworkMethod left, FrameworkMethod right) {
345            return NAME_ASCENDING.compare(left.getMethod(), right.getMethod());
346        }
347    }
348}