Uploaded image for project: 'JDK'
  1. JDK
  2. JDK-8146071

Math.max in unrolled loops may produce less optimal code than explicit comparison

XMLWordPrintable

    • Type: Bug
    • Resolution: Duplicate
    • Priority: P4
    • None
    • 8, 9
    • hotspot
    • x86_64
    • windows_7

      FULL PRODUCT VERSION :
      c:\work\java\jaxenter>java -version
      java version "1.8.0_65"
      Java(TM) SE Runtime Environment (build 1.8.0_65-b17)
      Java HotSpot(TM) 64-Bit Server VM (build 25.65-b01, mixed mode)

      c:\work\java\jaxenter>javac -version
      javac 1.8.0_65


      ADDITIONAL OS VERSION INFORMATION :
      Microsoft Windows [Version 6.1.7601] System 7 x64

      EXTRA RELEVANT SYSTEM CONFIGURATION :
      c:\work\java\jaxenter>java -version
      java version "1.8.0_65"
      Java(TM) SE Runtime Environment (build 1.8.0_65-b17)
      Java HotSpot(TM) 64-Bit Server VM (build 25.65-b01, mixed mode)

      c:\work\java\jaxenter>javac -version
      javac 1.8.0_65


      A DESCRIPTION OF THE PROBLEM :
      streamTest() runs 9X slower than streamTestA(), forTest(), and forTestA(). It appears that the optimizer does not take into account that Math::max refers to a static method of a final class, so its call target can never change. Note that streamTestB() and streamTestC() are only 2X slower than streamTestA(), even though streamTestC() passes the Math::max method reference to the compare() helper from inside the lambda that is given to reduce().

      This bug report is derived from an article posted at https://jaxenter.com/follow-up-how-fast-are-the-java-8-streams-122522.html.

      STEPS TO FOLLOW TO REPRODUCE THE PROBLEM :
      Compile the provided test case (javac ForVsStream.java) and run it: java ForVsStream


      EXPECTED VERSUS ACTUAL BEHAVIOR :
      EXPECTED -
      streamTest() should run at the same speed as streamTestA(), forTest() and forTestA()
      ACTUAL -
      streamTest() is 9X slower than streamTestA(), forTest(), and forTestA(), while streamTestB() and streamTestC() are about 4X faster than streamTest().

      REPRODUCIBILITY :
      This bug can be reproduced always.

      ---------- BEGIN SOURCE ----------
      import java.lang.Math;
      import java.util.Arrays;
      import java.util.Random;
      import static java.lang.System.currentTimeMillis;
      import static java.lang.System.nanoTime;

      /**
       * Micro-benchmark reproducing JDK-8146071: finds the maximum of a 1,000,000-element
       * {@code int[]} with an explicit comparison loop, a {@code Math.max} loop, and
       * {@code Arrays.stream(int[]).reduce(...)} with several different operator shapes.
       *
       * <p>Caveat: {@link #main} runs every variant exactly once, with no warm-up
       * iterations. Each reported time therefore includes whatever interpretation and JIT
       * compilation happens during that single run. Treat the numbers as a rough
       * illustration rather than a rigorous measurement.
       */
      class ForVsStream {

        /** Prints {@code string} followed by a line separator to {@code System.out}. */
        private static void println (String string) {
          System.out.println(string);
        }

        /**
         * Sink for timer readings in the overhead probes. Presumably volatile so that the
         * probed timer calls cannot be optimized away.
         */
        private static volatile long time = 0;

        /**
         * Estimated cost of one {@code nanoTime()} call, in nanoseconds. Set by
         * {@link #nanoTimerOverhead}; it is 0 until that method has run.
         */
        private static long nanoTimed = 0;

        /**
         * Estimates the cost of a single {@code System.nanoTime()} call.
         *
         * <p>Each sample brackets one {@code nanoTime()} call between two others and halves
         * the bracketed interval. The average is also stored in {@link #nanoTimed} so that
         * {@link #milliTimerOverhead} can subtract it.
         *
         * @return average estimated cost in nanoseconds
         */
        private static long nanoTimerOverhead () {
          final int bound = 10000;
          // long, not int: `int += long` narrows silently, so a single descheduled sample
          // longer than Integer.MAX_VALUE ns (about 2.1 s) would wrap the total.
          long total = 0;
          for (int count = 0; count < bound; ++count) {
            long start = nanoTime();
            time = nanoTime();
            long finish = nanoTime();
            total += (finish - start) / 2;
          }
          nanoTimed = total / bound;
          return nanoTimed;
        }

        /**
         * Estimates the cost of a single {@code System.currentTimeMillis()} call.
         *
         * <p>Each sample subtracts {@link #nanoTimed}, so {@link #nanoTimerOverhead} must be
         * called first; otherwise nothing is subtracted.
         *
         * @return average estimated cost in nanoseconds
         */
        private static long milliTimerOverhead () {
          final int bound = 10000;
          long total = 0; // long for the same wrap-around reason as in nanoTimerOverhead
          for (int count = 0; count < bound; ++count) {
            long start = nanoTime();
            time = currentTimeMillis();
            long finish = nanoTime();
            total += finish - start - nanoTimed;
          }
          return total / bound;
        }

        /** Benchmark input; filled with random values by {@link #randomInts}. */
        static final int[] ints = new int[1000000];

        /**
         * Returns the whole milliseconds elapsed since {@code startNanos}.
         *
         * <p>Uses {@code System.nanoTime()}, which is intended for measuring elapsed time.
         * {@code currentTimeMillis()} is wall-clock time and can jump if the system clock is
         * adjusted.
         *
         * @param startNanos a value previously returned by {@code System.nanoTime()}
         * @return elapsed time in milliseconds, truncated toward zero
         */
        private static long millisSince (long startNanos) {
          return (nanoTime() - startNanos) / 1_000_000L;
        }

        /**
         * Fills {@link #ints} with values from a freshly seeded {@link Random}.
         *
         * @return time taken in milliseconds
         */
        private static long randomInts () {
          long start = nanoTime();
          Random random = new Random();
          for (int i = 0; i < ints.length; i++) {
            ints[i] = random.nextInt();
          }
          return millisSince(start);
        }

        /**
         * Finds the maximum of {@link #ints} with a plain loop and an explicit {@code >}
         * comparison, then prints it.
         *
         * @return elapsed time of the loop in milliseconds
         */
        private static long forTest () {
          int[] a = ints;
          int e = ints.length;
          int m = Integer.MIN_VALUE;
          long start = nanoTime();
          for (int i = 0; i < e; i++) {
            if (a[i] > m) {
              m = a[i];
            }
          }
          long elapsed = millisSince(start);
          // Printing the result keeps m observably used, presumably so the JIT cannot
          // discard the loop.
          println("m: " + m);
          return elapsed;
        }

        /**
         * Finds the maximum of {@link #ints} with a plain loop calling {@code Math.max},
         * then prints it.
         *
         * @return elapsed time of the loop in milliseconds
         */
        private static long forTestA () {
          int[] a = ints;
          int e = ints.length;
          int m = Integer.MIN_VALUE;
          long start = nanoTime();
          for (int i = 0; i < e; i++) {
            m = Math.max(m, a[i]);
          }
          long elapsed = millisSince(start);
          println("m: " + m);
          return elapsed;
        }

        /**
         * Finds the maximum by reducing the stream with the method reference
         * {@code Math::max} (the slow case reported in this issue), then prints it.
         *
         * @return elapsed time of the reduction in milliseconds
         */
        private static long streamTest () {
          long start = nanoTime();
          int m = Arrays.stream(ints).reduce(Integer.MIN_VALUE, Math::max);
          long elapsed = millisSince(start);
          println("m: " + m);
          return elapsed;
        }

        /**
         * Finds the maximum by reducing the stream with a lambda that uses an explicit
         * conditional expression, then prints it.
         *
         * @return elapsed time of the reduction in milliseconds
         */
        private static long streamTestA () {
          long start = nanoTime();
          int m = Arrays.stream(ints).reduce(Integer.MIN_VALUE, (a, b) -> a > b ? a : b);
          long elapsed = millisSince(start);
          println("m: " + m);
          return elapsed;
        }

        /**
         * Finds the maximum by reducing the stream with a lambda that calls
         * {@code Math.max} directly, then prints it.
         *
         * @return elapsed time of the reduction in milliseconds
         */
        private static long streamTestB () {
          long start = nanoTime();
          int m = Arrays.stream(ints).reduce(Integer.MIN_VALUE, (a, b) -> Math.max(a, b));
          long elapsed = millisSince(start);
          println("m: " + m);
          return elapsed;
        }

        /**
         * A two-argument function whose operands and result share one type. Used to pass
         * {@code Math::max} through an extra level of indirection in {@link #streamTestC}.
         */
        @FunctionalInterface
        public interface BinaryCompareFunction<T> {
          public T compare (T a1, T a2);
        }

        /** Applies {@code f} to {@code a1} and {@code a2}. */
        private static <T> T compare (BinaryCompareFunction<T> f, T a1, T a2) {
          return f.compare(a1, a2);
        }

        /**
         * Finds the maximum by routing every pair through {@link #compare} with
         * {@code Math::max}, then prints it. Because {@code T} must be a reference type,
         * both operands are boxed to {@code Integer} on each step.
         *
         * @return elapsed time of the reduction in milliseconds
         */
        private static long streamTestC () {
          long start = nanoTime();
          int m = Arrays.stream(ints).reduce(Integer.MIN_VALUE, (a, b) -> compare(Math::max, a, b));
          long elapsed = millisSince(start);
          println("m: " + m);
          return elapsed;
        }

        /**
         * Prints the timer overheads, then the time of each max-finding variant.
         *
         * <p>Order matters: {@code nanoTimerOverhead()} must run before
         * {@code milliTimerOverhead()}, and {@code randomInts()} must fill the input before
         * any variant runs.
         */
        public static void main (String[] args) {
          println("nanoTimerOverhead: " + nanoTimerOverhead() + " nanoseconds");
          println("milliTimerOverhead: " + milliTimerOverhead() + " nanoseconds");
          println("randomInts(): " + randomInts() + " milliseconds");
          println("forTest(): " + forTest() + " milliseconds");
          println("forTestA(): " + forTestA() + " milliseconds");
          println("streamTest(): " + streamTest() + " milliseconds");
          println("streamTestA(): " + streamTestA() + " milliseconds");
          println("streamTestB(): " + streamTestB() + " milliseconds");
          println("streamTestC(): " + streamTestC() + " milliseconds");
        }

      }

      ---------- END SOURCE ----------

      CUSTOMER SUBMITTED WORKAROUND :
      See streamTestA(), streamTestB(), and streamTestC(). Note that streamTestA() is 2X faster than streamTestB() and streamTestC(), but all are 4X to 9X faster than streamTest().

            psandoz Paul Sandoz
            webbuggrp Webbug Group
            Votes:
            0 Vote for this issue
            Watchers:
            5 Start watching this issue

              Created:
              Updated:
              Resolved: