JDK-8214761

Bug in parallel Kahan summation implementation


    • b14
    • x86_64
    • windows_10
    • Verified

        A DESCRIPTION OF THE PROBLEM :
        DoubleStream.sum and related functions (Collectors.averagingDouble, Collectors.summingDouble and possibly others) all use Kahan summation to reduce numerical error. As I understand it, the implementations of these functions use a double array in which index 0 holds the high-order bits of the running sum and index 1 holds the negation of the low-order bits. The documentation incorrectly states that index 1 holds the low-order bits (without negation), and the code that combines two partial sums incorrectly adds the negation of the low-order bits instead of the bits themselves.
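        A minimal sketch of the sign convention described above, using a local copy of the OpenJDK 8 sumWithCompensation that is quoted in the test code (the class name KahanSignDemo and the helper accumulate are hypothetical). Adding 1e-16 to 1.0 does not change index 0, and index 1 ends up holding -1e-16, i.e. the negated low-order part, so under this convention the running total is approximated by sum[0] - sum[1].

```java
// Hypothetical demo class; sumWithCompensation is a local copy of the OpenJDK 8 code quoted in this report.
public class KahanSignDemo {

    static double[] sumWithCompensation(double[] intermediateSum, double value) {
        double tmp = value - intermediateSum[1];   // correct the incoming value by the stored compensation
        double sum = intermediateSum[0];
        double velvel = sum + tmp;                 // rounded running sum
        intermediateSum[1] = (velvel - sum) - tmp; // amount actually added minus amount intended
        intermediateSum[0] = velvel;
        return intermediateSum;
    }

    // Hypothetical helper: sequential Kahan summation of the given values.
    static double[] accumulate(double... values) {
        double[] s = new double[2];
        for (double v : values) {
            sumWithCompensation(s, v);
        }
        return s;
    }

    public static void main(String[] args) {
        double[] s = accumulate(1.0, 1e-16); // 1e-16 is below half an ulp of 1.0, so it is lost from s[0]
        System.out.println(s[0]);            // 1.0
        System.out.println(s[1]);            // -1.0E-16: the negated low-order part
    }
}
```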

        This problem appears in OpenJDK 8 and 11. I think that in https://hg.openjdk.java.net/jdk/jdk/file/8613f3fdbdae/src/java.base/share/classes/java/util/stream/DoublePipeline.java, line 432 should be changed to

        Collectors.sumWithCompensation(ll, -rr[1]);

        and in https://hg.openjdk.java.net/jdk/jdk/file/8613f3fdbdae/src/java.base/share/classes/java/util/stream/Collectors.java, line 729 should be changed to

        return sumWithCompensation(a, -b[1]); },

        and line 841 should be changed to

        (a, b) -> { sumWithCompensation(a, b[0]); sumWithCompensation(a, -b[1]); a[2] += b[2]; a[3] += b[3]; return a; },
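        The proposed change can be sketched by merging two partial sums by hand, the way a parallel reduction would. This is an illustration, not the JDK code: it uses a local copy of sumWithCompensation, and KahanCombineDemo, accumulate, combineCurrent and combineProposed are hypothetical names. The left partial sum covers {-1.0} and the right one covers {1.0, 1e-16}, so the exact total is 1e-16. The compensated part of the averaging combiner at line 841 is merged the same way.

```java
// Hypothetical demo class comparing the current and the proposed combiner.
public class KahanCombineDemo {

    // Local copy of the OpenJDK 8 Collectors code quoted in this report.
    static double[] sumWithCompensation(double[] intermediateSum, double value) {
        double tmp = value - intermediateSum[1];
        double sum = intermediateSum[0];
        double velvel = sum + tmp;
        intermediateSum[1] = (velvel - sum) - tmp; // negated low-order bits
        intermediateSum[0] = velvel;
        return intermediateSum;
    }

    // Sequential Kahan summation of one partition.
    static double[] accumulate(double... values) {
        double[] s = new double[2];
        for (double v : values) {
            sumWithCompensation(s, v);
        }
        return s;
    }

    // Current combiner: adds rr[1], which is the *negated* low-order part.
    static double[] combineCurrent(double[] ll, double[] rr) {
        sumWithCompensation(ll, rr[0]);
        sumWithCompensation(ll, rr[1]);
        return ll;
    }

    // Proposed combiner: flips the sign so the low-order part itself is added.
    static double[] combineProposed(double[] ll, double[] rr) {
        sumWithCompensation(ll, rr[0]);
        sumWithCompensation(ll, -rr[1]);
        return ll;
    }

    public static void main(String[] args) {
        double[] current = combineCurrent(accumulate(-1.0), accumulate(1.0, 1e-16));
        double[] proposed = combineProposed(accumulate(-1.0), accumulate(1.0, 1e-16));
        System.out.println(current[0]);  // -1.0E-16 (wrong sign)
        System.out.println(proposed[0]); // 1.0E-16
    }
}
```

        Only the sign of the second sumWithCompensation call differs between the two combiners; with the current sign the low-order part of the right-hand partial sum is subtracted instead of added.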

        Alternatively, sumWithCompensation itself could be changed instead of the combiners.
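        One way to read this alternative (an interpretation, not something spelled out above) is to make index 1 store the low-order bits with their true sign, so that the documented meaning of index 1, the existing combiners (which add rr[1]) and a final sum of [0] + [1] all agree. The sketch below uses a hypothetical sumWithPositiveCompensation; it is the same Kahan update with the compensation negated, and PositiveCompensationDemo and its helpers are illustrative names only.

```java
// Hypothetical demo class: compensation stored with its true sign.
public class PositiveCompensationDemo {

    // Hypothetical variant: s[1] holds the low-order bits themselves, not their negation.
    static double[] sumWithPositiveCompensation(double[] s, double value) {
        double tmp = value + s[1];     // re-inject the bits lost so far
        double sum = s[0];
        double velvel = sum + tmp;     // rounded running sum
        s[1] = tmp - (velvel - sum);   // bits lost by this addition, with their true sign
        s[0] = velvel;
        return s;
    }

    static double[] accumulate(double... values) {
        double[] s = new double[2];
        for (double v : values) {
            sumWithPositiveCompensation(s, v);
        }
        return s;
    }

    // Unchanged combiner shape: both slots of the right-hand sum are simply added.
    static double[] combine(double[] ll, double[] rr) {
        sumWithPositiveCompensation(ll, rr[0]);
        sumWithPositiveCompensation(ll, rr[1]);
        return ll;
    }

    public static void main(String[] args) {
        double[] s = accumulate(1.0, 1e-16);
        System.out.println(s[1]);                  // 1.0E-16, no negation
        double[] merged = combine(accumulate(-1.0), accumulate(1.0, 1e-16));
        System.out.println(merged[0] + merged[1]); // 1.0E-16
    }
}
```

        The trade-off is that the fix moves into the accumulator: every reader of index 1 (combiners and the final-sum step) must use the same sign convention.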

        I've attached test code below that compares the result of sum() using the current implementation with the result using the flipped sign. The results of the sequential sum and the naive sum are given for comparison, and for each approach the sum of squared errors (relative to a baseline computed with sequential Kahan summation) is printed. With the random inputs used, the sum of squared errors is consistently lower with the proposed sign flip.


        ---------- BEGIN SOURCE ----------
        package test;

        import java.util.Random;
        import java.util.stream.DoubleStream;

        public class TestSum {
            
            public static void main(String[] args) {
                double naive = 0;
                double sequentialStream = 0;
                double parallelStream = 0;
                double mySequentialStream = 0;
                double myParallelStream = 0;
                
                for (int loop = 0; loop < 100; loop++) {
                    // sequence of random numbers of varying magnitudes, both positive and negative
                    double[] rand = new Random().doubles(1_000_000)
                            .map(Math::log)
                            .map(x -> (Double.doubleToLongBits(x) % 2 == 0) ? x : -x)
                            .toArray();
                    
                    // base case: standard Kahan summation
                    double[] sum = new double[2];
                    for (int i=0; i < rand.length; i++) {
                        sumWithCompensation(sum, rand[i]);
                    }
                    
                    // squared error of naive sum by reduction - should be large
                    naive += Math.pow(DoubleStream.of(rand).reduce((x, y) -> x+y).getAsDouble() - sum[0], 2);
                    
                    // squared error of sequential sum - should be close to 0
                    sequentialStream += Math.pow(DoubleStream.of(rand).sum() - sum[0], 2);
                    
                    // squared error of parallel sum
                    parallelStream += Math.pow(DoubleStream.of(rand).parallel().sum() - sum[0], 2);
                    
                    // squared error of modified sequential sum - should be close to 0
                    mySequentialStream += Math.pow(computeFinalSum(DoubleStream.of(rand).collect(
                            () -> new double[3],
                            (ll, d) -> {
                                sumWithCompensation(ll, d);
                                ll[2] += d;
                            },
                            (ll, rr) -> {
                                sumWithCompensation(ll, rr[0]);
                                sumWithCompensation(ll, -rr[1]); // minus is added
                                ll[2] += rr[2];
                            })) - sum[0], 2);
                    
                    // squared error of modified parallel sum - typically ~0.25-0.5 times squared error of parallel sum
                    myParallelStream += Math.pow(computeFinalSum(DoubleStream.of(rand).parallel().collect(
                            () -> new double[3],
                            (ll, d) -> {
                                sumWithCompensation(ll, d);
                                ll[2] += d;
                            },
                            (ll, rr) -> {
                                sumWithCompensation(ll, rr[0]);
                                sumWithCompensation(ll, -rr[1]); // minus is added
                                ll[2] += rr[2];
                            })) - sum[0], 2);
                }
                
                // print sum of squared errors
                System.out.println(naive);
                System.out.println(sequentialStream);
                System.out.println(parallelStream);
                System.out.println(mySequentialStream);
                System.out.println(myParallelStream);
            }
            
            // from OpenJDK8 Collectors, unmodified
            static double[] sumWithCompensation(double[] intermediateSum, double value) {
                double tmp = value - intermediateSum[1];
                double sum = intermediateSum[0];
                double velvel = sum + tmp; // Little wolf of rounding error
                intermediateSum[1] = (velvel - sum) - tmp;
                intermediateSum[0] = velvel;
                return intermediateSum;
            }
            
            // from OpenJDK8 Collectors, unmodified
            static double computeFinalSum(double[] summands) {
                double tmp = summands[0] + summands[1];
                double simpleSum = summands[summands.length - 1];
                if (Double.isNaN(tmp) && Double.isInfinite(simpleSum))
                    return simpleSum;
                else
                    return tmp;
            }

        }

        ---------- END SOURCE ----------

        FREQUENCY : always


              Assignee: Ian Graves (igraves)
              Reporter: Webbug Group (webbuggrp)