JDK-8273056

java.util.random does not correctly sample exponential or Gaussian distributions


    • b27
    • generic
    • generic

        A DESCRIPTION OF THE PROBLEM :
        The modified ziggurat algorithm is not correctly implemented in java.base/jdk/internal/util/random/RandomSupport.java.

        Create a histogram of a million samples using 2000 uniform bins over the following ranges:
        Exponential: 0 to 12. Gaussian: -8 to 8.

        The samples do not pass a chi-square goodness-of-fit test, and the histogram visibly departs from the shape of the PDF for both distributions. The discrepancy is clearest close to zero (e.g. within +/- 0.5).
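        The goodness-of-fit check described above can be sketched with the standard library alone. This is a minimal illustration, not the reporter's test: the class and method names are hypothetical, it uses 50 bins over 0 to 5 (with an open-ended last bin) so that every expected count is large, and it uses inversion, -ln(1 - U), as a known-correct reference exponential sampler. For a correct sampler the Pearson statistic should be close to its degrees of freedom (bins - 1 = 49).

```java
import java.util.Random;
import java.util.function.DoubleSupplier;

/** Hypothetical helper: Pearson chi-square statistic for Exp(1) samples. */
public class ExpChiSquare {

    /**
     * Bins samples into uniform bins of width max / bins starting at 0. The last
     * bin is open-ended, [max - width, +infinity), so every sample is counted.
     * Samples are assumed to be non-negative.
     */
    static double statistic(DoubleSupplier sampler, int samples, int bins, double max) {
        final double width = max / bins;
        final long[] observed = new long[bins];
        for (int i = 0; i < samples; i++) {
            final double x = sampler.getAsDouble();
            // Clamp so that the final bin absorbs the tail
            final int index = (int) Math.min(bins - 1, Math.floor(x / width));
            observed[index]++;
        }
        double chi2 = 0;
        for (int i = 0; i < bins; i++) {
            final double lo = i * width;
            // Exponential CDF is 1 - exp(-x), so P(lo <= X < hi) = exp(-lo) - exp(-hi)
            final double p = (i == bins - 1)
                ? Math.exp(-lo)
                : Math.exp(-lo) - Math.exp(-(lo + width));
            final double expected = p * samples;
            final double d = observed[i] - expected;
            chi2 += d * d / expected;
        }
        return chi2;
    }

    public static void main(String[] args) {
        final Random rng = new Random(42);
        // Reference sampler by inversion: -ln(1 - U) is Exp(1) for U uniform on [0, 1)
        final double good = statistic(() -> -Math.log(1 - rng.nextDouble()), 200_000, 50, 5.0);
        System.out.printf("chi-square = %.1f with %d degrees of freedom%n", good, 49);
    }
}
```

        Passing a broken sampler instead, for example one that is uniform on [0, 5), produces a statistic that is orders of magnitude larger.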

        The following steps can be used to correct the implementation:

        Exponential:
        1. When the sample is not within the main ziggurat, the deviate U1 is recycled, including its sign bit. If the next region selected is an overhang (j > 0), the sign bit must be cleared from U1. The inner loop does this when it creates a new U1, but not on the first entry to that loop. Corrected using:

            if (j > 0) { // Sample overhang j
                // U1 is recycled bits. It must be non-negative.
                U1 = U1 >>> 1;
                // ... sample the overhang using U1 and U2 ...
            }

        2. When the loop samples the overhang (j > 0), the value x is computed from U1. However, reflection in the hypotenuse (used to avoid the upper-right triangle) swaps U1 and U2, so an x computed before the swap corresponds to what is now U2 and is invalid. x should be computed after any reflection.

        3. y is not computed correctly.
        The value of y is computed as:
        y = (Y[j] * 0x1.0p63) + ((Y[j] - Y[j-1]) * (double)U2);

        (X[j],Y[j]) corresponds to the upper-left corner of the overhang rectangle. (X[j-1],Y[j-1]) corresponds to the lower-right corner.

        U1 is the distance to move right from the left side of the rectangle:
        x = (X[j] * 0x1.0p63) + ((X[j-1] - X[j]) * (double)U1)
        This is implemented correctly.

        U2 is the distance to move down from the top side of the rectangle.
        The original paper by McFarland used:
        y = (Y[j-1] * 0x1.0p63) + ((Y[j] - Y[j-1]) * (double)(0x1.0p63-U2));
        This effectively moves up from the bottom of the rectangle by the fraction 1 - u2 (where u2 = U2 / 2^63), which is equivalent to moving down from the top by u2.

        The code is currently using:
        y = (Y[j] * 0x1.0p63) + ((Y[j] - Y[j-1]) * (double)U2);
        Since Y[j] is the top and Y[j-1] is the bottom, this adds a positive offset to the top and moves out of the rectangle, so any y created this way is rejected. I ran a coverage tool on the exponential sampler method with issues 1 and 2 fixed, so that only y is computed incorrectly: the branch where y is below the curve is never covered.

        The code can be fixed to move down from the top of the rectangle:
        y = (Y[j] * 0x1.0p63) + ((Y[j-1] - Y[j]) * (double)U2);
        where (Y[j-1] - Y[j]) is negative. This simply swaps the indices used to compute the height of the rectangle.
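        The three exponential fixes can be combined into a sketch of choosing one point in overhang j. This is an illustration under stated assumptions, not the RandomSupport code: the class, method names and the two-entry tables are hypothetical. The reflection test (U1 > U2) is derived from the geometry rather than copied from the JDK: with x moving right by U1 and y moving down by U2, the hypotenuse is U1 == U2 and the upper-right triangle is U1 > U2. Coordinates stay in the 2^63 fixed-point scale used by the formulas above.

```java
/** Hypothetical sketch of sampling a point in overhang j with fixes 1 to 3 applied. */
public class OverhangSketch {

    /** Fix 1: recycled bits may carry the sign bit, so shift it out. */
    static long toNonNegative(long recycled) {
        return recycled >>> 1;
    }

    /**
     * (X[j], Y[j]) is the upper-left corner and (X[j-1], Y[j-1]) the lower-right
     * corner. U1 and U2 are non-negative 63-bit integers. Returns {x, y}, both
     * scaled by 2^63.
     */
    static double[] samplePoint(double[] X, double[] Y, int j, long U1, long U2) {
        if (U1 > U2) {
            // Reflect in the hypotenuse so the point lies in the lower-left triangle.
            // Fix 2: the swap happens BEFORE x is computed from U1.
            final long t = U1;
            U1 = U2;
            U2 = t;
        }
        // Move right from the left side of the rectangle by the fraction U1 / 2^63
        final double x = (X[j] * 0x1.0p63) + ((X[j - 1] - X[j]) * (double) U1);
        // Fix 3: move down from the top side by the fraction U2 / 2^63
        final double y = (Y[j] * 0x1.0p63) + ((Y[j - 1] - Y[j]) * (double) U2);
        return new double[] {x, y};
    }

    /** y as reported for the current code: it moves up, out of the rectangle. */
    static double currentY(double[] Y, int j, long U2) {
        return (Y[j] * 0x1.0p63) + ((Y[j] - Y[j - 1]) * (double) U2);
    }

    public static void main(String[] args) {
        // Hypothetical two-entry tables for exp(-x): overhang 1 spans 0.5 <= x <= 1.0
        final double[] X = {1.0, 0.5};
        final double[] Y = {Math.exp(-1.0), Math.exp(-0.5)};
        // -1L has the sign bit set; after the shift it equals Long.MAX_VALUE
        final long u1 = toNonNegative(-1L) / 4; // roughly 0.25 * 2^63
        final long u2 = Long.MAX_VALUE / 2;     // roughly 0.5 * 2^63
        final double[] p = samplePoint(X, Y, 1, u1, u2);
        final double x = p[0] * 0x1.0p-63;
        final double y = p[1] * 0x1.0p-63;
        System.out.printf("x = %.4f, y = %.4f, accepted = %b%n", x, y, y < Math.exp(-x));
        System.out.printf("current y = %.4f (top of rectangle = %.4f)%n",
                currentY(Y, 1, u2) * 0x1.0p-63, Y[1]);
    }
}
```

        Swapping the arguments (U1 > U2) yields the same point, because the reflection runs before x is computed; the current y formula always lands above Y[j].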

        Gaussian:
        1. Fix sampling when reflection occurs by computing x after the reflection. This is the same as #2 for the exponential.

        2. Fix the computation of y to be inside the rectangle. This is the same as #3 for the exponential.
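        The effect of fix 2 on a Gaussian overhang can be checked with a small Monte Carlo sketch that counts how many points fall below the curve under the current and the corrected y formulas. Assumptions: the class name and the overhang 0.5 <= x <= 1.0 are hypothetical, coordinates use unit scale (u1, u2 uniform on [0, 1)) instead of the 2^63 fixed-point scale, the density is left unnormalised, and reflection is omitted because it does not affect where y lands.

```java
import java.util.Random;

/** Hypothetical check of accepted points in a Gaussian overhang. */
public class GaussianOverhangCheck {

    /** Unnormalised Gaussian density exp(-x^2 / 2). */
    static double f(double x) {
        return Math.exp(-0.5 * x * x);
    }

    /** Returns {acceptedCurrent, acceptedCorrected} over n random points. */
    static long[] countAccepted(double left, double right, int n, long seed) {
        final double top = f(left);     // Y[j]: upper-left corner
        final double bottom = f(right); // Y[j-1]: lower-right corner
        final Random rng = new Random(seed);
        long current = 0;
        long corrected = 0;
        for (int i = 0; i < n; i++) {
            final double u1 = rng.nextDouble();
            final double u2 = rng.nextDouble();
            // Move right from the left side of the rectangle
            final double x = left + (right - left) * u1;
            // Current formula: adds a non-negative height, so y >= top >= f(x)
            final double yCurrent = top + (top - bottom) * u2;
            // Corrected formula: moves down from the top, staying in the rectangle
            final double yCorrected = top + (bottom - top) * u2;
            final double fx = f(x);
            if (yCurrent < fx) {
                current++;
            }
            if (yCorrected < fx) {
                corrected++;
            }
        }
        return new long[] {current, corrected};
    }

    public static void main(String[] args) {
        final long[] counts = countAccepted(0.5, 1.0, 100_000, 1L);
        System.out.println("current formula accepted:   " + counts[0]);
        System.out.println("corrected formula accepted: " + counts[1]);
    }
}
```

        With the current formula the accepted count is always zero, consistent with the coverage observation for the exponential sampler.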


        STEPS TO FOLLOW TO REPRODUCE THE PROBLEM :
        Use the new RandomGenerator interface to create samples. Histogram them using small bins and inspect the result: it should trace the outline of the exponential/Gaussian PDF.
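        A minimal reproduction of these steps might look like the sketch below (Java 17 or later). The class name is hypothetical, and it assumes that SplittableRandom inherits the default nextGaussian() and nextExponential() of RandomGenerator, so that the ziggurat code in RandomSupport is exercised. Expected counts use a midpoint approximation, N * width * pdf(mid), which is adequate for bins this narrow.

```java
import java.util.SplittableRandom;
import java.util.random.RandomGenerator;

/** Hypothetical reproduction: histogram RandomGenerator samples near zero. */
public class HistogramRepro {

    /** Counts samples in uniform bins over [min, max); outliers go to the end bins. */
    static long[] histogram(RandomGenerator rng, boolean exponential,
                            int samples, int bins, double min, double max) {
        final long[] counts = new long[bins];
        final double width = (max - min) / bins;
        for (int i = 0; i < samples; i++) {
            final double x = exponential ? rng.nextExponential() : rng.nextGaussian();
            final double pos = Math.floor((x - min) / width);
            counts[(int) Math.max(0, Math.min(bins - 1, pos))]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        final int bins = 2000;
        final int samples = 1_000_000;
        final double min = -8;
        final double max = 8;
        final double width = (max - min) / bins;
        // Assumption: SplittableRandom uses the RandomGenerator default nextGaussian()
        final long[] counts = histogram(new SplittableRandom(), false, samples, bins, min, max);
        final int first = (int) Math.floor((-0.1 - min) / width);
        final int last = (int) Math.floor((0.1 - min) / width);
        for (int i = first; i <= last; i++) {
            final double mid = min + (i + 0.5) * width;
            // Midpoint approximation of the expected count: N * width * pdf(mid)
            final double expected = samples * width
                * Math.exp(-0.5 * mid * mid) / Math.sqrt(2 * Math.PI);
            System.out.printf("%8.4f %6d %8.1f%n", mid, counts[i], expected);
        }
    }
}
```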


        ---------- BEGIN SOURCE ----------
        I extracted the RandomSupport class and ran it under Java 8 using SplittableRandom. This required updating RandomSupport to accept a SplittableRandom in place of a RandomGenerator. Running under Java 8 allows Apache Commons Math to be used to perform the chi-square test. I do not know what the equivalent test environment should be for JDK 17.

        The testGaussianSamplesWithQuantiles test often passes. The other tests consistently fail.

        ---
        import org.junit.jupiter.api.Assertions;
        import org.junit.jupiter.api.Test;
        import java.util.Arrays;
        import java.util.SplittableRandom;
        import java.util.function.DoubleSupplier;
        import org.apache.commons.math3.distribution.AbstractRealDistribution;
        import org.apache.commons.math3.distribution.ExponentialDistribution;
        import org.apache.commons.math3.distribution.NormalDistribution;
        import org.apache.commons.math3.stat.inference.ChiSquareTest;

        /**
         * Test for {@link RandomSupport}.
         */
        class RandomSupportTest {

            /**
             * Test Gaussian samples using a large number of bins based on uniformly spaced quantiles.
             */
            @Test
            void testGaussianSamplesWithQuantiles() {
                final int bins = 2000;
                final NormalDistribution dist = new NormalDistribution(null, 0.0, 1.0);
                final double[] quantiles = new double[bins];
                for (int i = 0; i < bins; i++) {
                    quantiles[i] = dist.inverseCumulativeProbability((i + 1.0) / bins);
                }
                testSamples(quantiles, false);
            }

            /**
             * Test Gaussian samples using a large number of bins uniformly spaced in a range.
             */
            @Test
            void testGaussianSamplesWithUniformValues() {
                final int bins = 2000;
                final double[] values = new double[bins];
                final double minx = -8;
                final double maxx = 8;
                for (int i = 0; i < bins; i++) {
                    values[i] = minx + (maxx - minx) * (i + 1.0) / bins;
                }
                // Ensure upper bound is the support limit
                values[bins - 1] = Double.POSITIVE_INFINITY;
                testSamples(values, false);
            }

            /**
             * Test exponential samples using a large number of bins based on uniformly spaced quantiles.
             */
            @Test
            void testExponentialSamplesWithQuantiles() {
                final int bins = 2000;
                final ExponentialDistribution dist = new ExponentialDistribution(null, 1.0);
                final double[] quantiles = new double[bins];
                for (int i = 0; i < bins; i++) {
                    quantiles[i] = dist.inverseCumulativeProbability((i + 1.0) / bins);
                }
                testSamples(quantiles, true);
            }

            /**
             * Test exponential samples using a large number of bins uniformly spaced in a range.
             */
            @Test
            void testExponentialSamplesWithUniformValues() {
                final int bins = 2000;
                final double[] values = new double[bins];
                final double minx = 0;
                // Enter the tail of the distribution
                final double maxx = 12;
                for (int i = 0; i < bins; i++) {
                    values[i] = minx + (maxx - minx) * (i + 1.0) / bins;
                }
                // Ensure upper bound is the support limit
                values[bins - 1] = Double.POSITIVE_INFINITY;
                testSamples(values, true);
            }

            /**
             * Test samples using the provided bins. Values correspond to the bin upper limit. It
             * is assumed the values span most of the distribution. Additional tests are performed
             * using a region of the distribution sampled using the edge of the ziggurat.
             *
             * @param values Bin upper limits
             * @param exponential Set to true to use an exponential sampler
             */
            private static void testSamples(double[] values,
                                            boolean exponential) {
                final int bins = values.length;

                final int samples = 10000000;
                final long[] observed = new long[bins];
                final SplittableRandom rng = new SplittableRandom();
                final DoubleSupplier sampler = exponential ?
                    () -> RandomSupport.computeNextExponential(rng) :
                    () -> RandomSupport.computeNextGaussian(rng);
                for (int i = 0; i < samples; i++) {
                    final double x = sampler.getAsDouble();
                    final int index = findIndex(values, x);
                    observed[index]++;
                }

                // Compute expected
                final AbstractRealDistribution dist = exponential ?
                    new ExponentialDistribution(null, 1.0) : new NormalDistribution(null, 0.0, 1.0);
                final double[] expected = new double[bins];
                double x0 = Double.NEGATIVE_INFINITY;
                for (int i = 0; i < bins; i++) {
                    final double x1 = values[i];
                    expected[i] = dist.probability(x0, x1);
                    x0 = x1;
                }

                final double significanceLevel = 0.001;

                final double lowerBound = dist.getSupportLowerBound();

                final ChiSquareTest chiSquareTest = new ChiSquareTest();
                // Pass if we cannot reject null hypothesis that the distributions are the same.
                final double pValue = chiSquareTest.chiSquareTest(expected, observed);
                Assertions.assertFalse(pValue < significanceLevel,
                    () -> String.format("(%s <= x < %s) Chi-square p-value = %s",
                                        lowerBound, values[bins - 1], pValue));

                // Test bins sampled from the edge of the ziggurat. This is always around zero.
                for (final double range : new double[] {0.5, 0.25, 0.1, 0.05}) {
                    final int min = findIndex(values, -range);
                    final int max = findIndex(values, range);
                    final long[] observed2 = Arrays.copyOfRange(observed, min, max + 1);
                    final double[] expected2 = Arrays.copyOfRange(expected, min, max + 1);
                    final double pValue2 = chiSquareTest.chiSquareTest(expected2, observed2);
                    Assertions.assertFalse(pValue2 < significanceLevel,
                        () -> String.format("(%s <= x < %s) Chi-square p-value = %s",
                                            min == 0 ? lowerBound : values[min - 1], values[max], pValue2));
                }
            }

            /**
             * Find the index of the value in the data such that:
             * <pre>
             * data[index - 1] <= x < data[index]
             * </pre>
             *
             * <p>This is a specialised binary search that assumes the bounds of the data are the
             * extremes of the support, and the upper support is infinite. Thus an index cannot
             * be returned as equal to the data length.
             *
             * @param data the data
             * @param x the value
             * @return the index
             */
            private static int findIndex(double[] data, double x) {
                int low = 0;
                int high = data.length - 1;

                // Bracket so that low is just above the value x
                while (low <= high) {
                    final int mid = (low + high) >>> 1;
                    final double midVal = data[mid];

                    if (x < midVal) {
                        // Reduce search range
                        high = mid - 1;
                    } else {
                        // Set data[low] above the value
                        low = mid + 1;
                    }
                }
                // Verify the index is correct
                Assertions.assertTrue(x < data[low]);
                if (low != 0) {
                    Assertions.assertTrue(x >= data[low - 1]);
                }
                return low;
            }
        }

        ---------- END SOURCE ----------

        FREQUENCY : always


              Guy Steele (Inactive)
              Webbug Group