C2: Missing Ideal optimizations for load and store vectors on SVE


    • b02
    • aarch64
    • linux

        The issue appears to be caused by mistakes made in JDK-8286941 (JDK 20), combined with the verification added in JDK-8367389 (JDK 26).

        As a consequence of JDK-8286941, we do not perform all Ideal optimizations on LoadVector and StoreVector nodes. This is a problem in its own right, because we miss optimizations we could otherwise apply.

        However, with the verification from JDK-8367389 we now hit a particular case: multiple LoadVector nodes have separate MergeMem nodes as their memory inputs. Ideal could have stepped over these MergeMem nodes and reached the same memory state for all of the loads, but the mistakes from JDK-8286941 prevent this. As a result, loads that SHOULD share the same memory state instead have different MergeMem inputs, which triggers the assert in SuperWord introduced with JDK-8367389.
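        To make the failure mode concrete, here is a toy Java model. It is not HotSpot code: `Mem`, `MergeMem`, `Load`, `stepOverMergeMem` and `sameInputPerSlice` are hypothetical names. It assumes a simplified MergeMem that maps an alias index to the memory state of that slice, and a check that, like the assert in `VLoopMemorySlices::find_memory_slices`, requires all loads of one slice to share the same memory input (`load->in(1)`). Two loads that reach the same loop memory state through different MergeMem nodes fail the check until each one is stepped over its MergeMem.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Toy model (not HotSpot code) of the memory-slice check; all names are hypothetical. */
public class MemorySliceToy {
    /** A memory state in the toy graph. */
    static class Mem {
        final String name;
        Mem(String name) { this.name = name; }
        @Override public String toString() { return name; }
    }

    /** Toy MergeMem: maps an alias index to the memory state of that slice, else falls back to base. */
    static class MergeMem extends Mem {
        final Map<Integer, Mem> slices = new HashMap<>();
        final Mem base;
        MergeMem(String name, Mem base) { super(name); this.base = base; }
        Mem memoryAt(int aliasIdx) { return slices.getOrDefault(aliasIdx, base); }
    }

    /** Toy load: reads slice aliasIdx through memory input mem (the analogue of load->in(1)). */
    record Load(int aliasIdx, Mem mem) {}

    /** Ideal-style step: follow MergeMem nodes to the memory state of the load's own slice. */
    static Load stepOverMergeMem(Load load) {
        Mem m = load.mem();
        while (m instanceof MergeMem mm) {
            m = mm.memoryAt(load.aliasIdx());
        }
        return new Load(load.aliasIdx(), m);
    }

    /** Spirit of the assert: per alias index, every load must have the same memory input. */
    static boolean sameInputPerSlice(List<Load> loads) {
        Map<Integer, Mem> inputs = new HashMap<>();
        for (Load l : loads) {
            Mem prev = inputs.putIfAbsent(l.aliasIdx(), l.mem());
            if (prev != null && prev != l.mem()) {
                return false; // "not yet touched or the same input" would fail here
            }
        }
        return true;
    }

    public static void main(String[] args) {
        int boolArraySlice = 5; // hypothetical alias index for the boolean[] slice
        Mem loopPhi = new Mem("loop memory phi");
        MergeMem mm1 = new MergeMem("MergeMem#1", new Mem("other memory A"));
        MergeMem mm2 = new MergeMem("MergeMem#2", new Mem("other memory B"));
        mm1.slices.put(boolArraySlice, loopPhi);
        mm2.slices.put(boolArraySlice, loopPhi);

        List<Load> loads = List.of(new Load(boolArraySlice, mm1), new Load(boolArraySlice, mm2));
        System.out.println("before: " + sameInputPerSlice(loads)); // prints "before: false"

        List<Load> stepped = new ArrayList<>();
        for (Load l : loads) {
            stepped.add(stepOverMergeMem(l));
        }
        System.out.println("after:  " + sameInputPerSlice(stepped)); // prints "after:  true"
    }
}
```

        The model deliberately ignores control flow, phis and the real alias-index computation; it only shows why stepping over MergeMem restores a shared memory input, and why skipping that step leaves the loads with distinct inputs.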

        ---------------------------------------------- ORIGINAL REPORT ------------------------------------------------------------

        When running a JMH benchmark on an **AWS Graviton3 machine (with 256-bit SVE support)**, I hit the following assert failure. The log is as follows:

        ```
        # Run progress: 0.00% complete, ETA 00:00:14
        # Fork: 1 of 1
        WARNING: Using incubator modules: jdk.incubator.vector
        WARNING: A terminally deprecated method in sun.misc.Unsafe has been called
        WARNING: sun.misc.Unsafe::objectFieldOffset has been called by org.openjdk.jmh.util.Utils (file:/localhome/jadmin/erfang/jdk/build/linux-aarch64-server-fastdebug/images/test/micro/benchmarks.jar)
        WARNING: Please consider reporting this to the maintainers of class org.openjdk.jmh.util.Utils
        WARNING: sun.misc.Unsafe::objectFieldOffset will be removed in a future release
        # Warmup Iteration 1: #
        # A fatal error has been detected by the Java Runtime Environment:
        #
        # Internal Error (/localhome/jadmin/erfang/jdk/src/hotspot/share/opto/vectorization.cpp:231), pid=167298, tid=167336
        # assert(_inputs.at(alias_idx) == nullptr || _inputs.at(alias_idx) == load->in(1)) failed: not yet touched or the same input
        #
        # JRE version: OpenJDK Runtime Environment (26.0) (fastdebug build 26-internal-adhoc.jadmin.jdk)
        # Java VM: OpenJDK 64-Bit Server VM (fastdebug 26-internal-adhoc.jadmin.jdk, mixed mode, sharing, tiered, compressed oops, compressed class ptrs, g1 gc, linux-aarch64)
        # Problematic frame:
        # V [libjvm.so+0x1b2d5ec] VLoopMemorySlices::find_memory_slices()+0x29c
        #
        # Core dump will be written. Default location: Core dumps may be processed with "/usr/share/apport/apport -p%p -s%s -c%c -d%d -P%P -u%u -g%g -F%F -- %E" (or dumping to /localhome/jadmin/erfang/jdk/build/linux-aarch64-server-fastdebug/images/test/core.167298)
        #
        # An error report file with more information is saved as:
        # /localhome/jadmin/erfang/jdk/build/linux-aarch64-server-fastdebug/images/test/hs_err_pid167298.log
        ^C
        ERROR: Build failed for target 'test' in configuration 'linux-aarch64-server-fastdebug' (exit code 141)

        No indication of failed target found.
        HELP: Try searching the build log for '] Error'.
        HELP: Run 'make doctor' to diagnose build problems.

        make[1]: *** [/localhome/jadmin/erfang/jdk/make/Init.gmk:151: main] Error 141
        make: *** [/localhome/jadmin/erfang/jdk/make/PreInit.gmk:159: test] Interrupt
        ```

        The test case is a new JMH benchmark file; place it at `test/micro/org/openjdk/bench/jdk/incubator/vector/MaskLastTrueBenchmark.java`.

        The code is as follows:
        ```
        /*
         * Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
         * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
         *
         * This code is free software; you can redistribute it and/or modify it
         * under the terms of the GNU General Public License version 2 only, as
         * published by the Free Software Foundation.
         *
         * This code is distributed in the hope that it will be useful, but WITHOUT
         * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
         * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
         * version 2 for more details (a copy is included in the LICENSE file that
         * accompanied this code).
         *
         * You should have received a copy of the GNU General Public License version
         * 2 along with this work; if not, write to the Free Software Foundation,
         * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
         *
         * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
         * or visit www.oracle.com if you need additional information or have any
         * questions.
         */

        package org.openjdk.bench.jdk.incubator.vector;

        import java.util.Random;
        import jdk.incubator.vector.*;
        import java.util.concurrent.TimeUnit;
        import org.openjdk.jmh.annotations.*;

        @OutputTimeUnit(TimeUnit.MILLISECONDS)
        @State(Scope.Thread)
        @Warmup(iterations = 4, time = 2)
        @Measurement(iterations = 6, time = 1)
        @Fork(value = 1, jvmArgs = {"--add-modules=jdk.incubator.vector"})
        public class MaskLastTrueBenchmark {
            @Param({"128"})
            int size;

            private static final VectorSpecies<Byte> bspecies = VectorSpecies.ofLargestShape(byte.class);
            private static final VectorSpecies<Short> sspecies = VectorSpecies.ofLargestShape(short.class);
            private static final VectorSpecies<Integer> ispecies = VectorSpecies.ofLargestShape(int.class);
            private static final VectorSpecies<Long> lspecies = VectorSpecies.ofLargestShape(long.class);
            private static final VectorSpecies<Float> fspecies = VectorSpecies.ofLargestShape(float.class);
            private static final VectorSpecies<Double> dspecies = VectorSpecies.ofLargestShape(double.class);

            byte[] byte_arr;
            short[] short_arr;
            int[] int_arr;
            long[] long_arr;
            float[] float_arr;
            double[] double_arr;
            boolean[] mask_arr;

            @Setup(Level.Trial)
            public void BmSetup() {
                Random r = new Random();
                byte_arr = new byte[size];
                short_arr = new short[size];
                int_arr = new int[size];
                long_arr = new long[size];
                float_arr = new float[size];
                double_arr = new double[size];
                mask_arr = new boolean[size];

                for (int i = 0; i < size; i++) {
                    byte_arr[i] = (byte) r.nextInt();
                    short_arr[i] = (short) r.nextInt();
                    int_arr[i] = r.nextInt();
                    long_arr[i] = r.nextLong();
                    float_arr[i] = r.nextFloat();
                    double_arr[i] = r.nextDouble();
                    mask_arr[i] = r.nextBoolean();
                }
            }

            // VectorMask.fromArray + lastTrue

            @Benchmark
            public int testLastTrueFromArrayByte() {
                int sum = 0;
                for (int i = 0; i < size; i += bspecies.length()) {
                    VectorMask<Byte> m = VectorMask.fromArray(bspecies, mask_arr, i);
                    sum += m.lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueFromArrayShort() {
                int sum = 0;
                for (int i = 0; i < size; i += sspecies.length()) {
                    VectorMask<Short> m = VectorMask.fromArray(sspecies, mask_arr, i);
                    sum += m.lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueFromArrayInt() {
                int sum = 0;
                for (int i = 0; i < size; i += ispecies.length()) {
                    VectorMask<Integer> m = VectorMask.fromArray(ispecies, mask_arr, i);
                    sum += m.lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueFromArrayLong() {
                int sum = 0;
                for (int i = 0; i < size; i += lspecies.length()) {
                    VectorMask<Long> m = VectorMask.fromArray(lspecies, mask_arr, i);
                    sum += m.lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueFromArrayFloat() {
                int sum = 0;
                for (int i = 0; i < size; i += fspecies.length()) {
                    VectorMask<Float> m = VectorMask.fromArray(fspecies, mask_arr, i);
                    sum += m.lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueFromArrayDouble() {
                int sum = 0;
                for (int i = 0; i < size; i += dspecies.length()) {
                    VectorMask<Double> m = VectorMask.fromArray(dspecies, mask_arr, i);
                    sum += m.lastTrue();
                }
                return sum;
            }


            // Vector.compare + lastTrue

            @Benchmark
            public int testLastTrueCompareByte() {
                int sum = 0;
                for (int i = 0; i < size; i += bspecies.length()) {
                    ByteVector v = ByteVector.fromArray(bspecies, byte_arr, i);
                    sum += v.compare(VectorOperators.LT, 0).lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueCompareShort() {
                int sum = 0;
                for (int i = 0; i < size; i += sspecies.length()) {
                    ShortVector v = ShortVector.fromArray(sspecies, short_arr, i);
                    sum += v.compare(VectorOperators.LT, 0).lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueCompareInt() {
                int sum = 0;
                for (int i = 0; i < size; i += ispecies.length()) {
                    IntVector v = IntVector.fromArray(ispecies, int_arr, i);
                    sum += v.compare(VectorOperators.LT, 0).lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueCompareLong() {
                int sum = 0;
                for (int i = 0; i < size; i += lspecies.length()) {
                    LongVector v = LongVector.fromArray(lspecies, long_arr, i);
                    sum += v.compare(VectorOperators.LT, 0).lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueCompareFloat() {
                int sum = 0;
                for (int i = 0; i < size; i += fspecies.length()) {
                    FloatVector v = FloatVector.fromArray(fspecies, float_arr, i);
                    sum += v.compare(VectorOperators.LT, 0).lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueCompareDouble() {
                int sum = 0;
                for (int i = 0; i < size; i += dspecies.length()) {
                    DoubleVector v = DoubleVector.fromArray(dspecies, double_arr, i);
                    sum += v.compare(VectorOperators.LT, 0).lastTrue();
                }
                return sum;
            }


            // VectorMask.indexInRange + lastTrue

            @Benchmark
            public int testLastTrueIndexInRangeByte() {
                int sum = 0;
                int limit = 0;
                VectorMask<Byte> m = VectorMask.fromArray(bspecies, mask_arr, 0);
                for (int i = 0; i < size; i += bspecies.length()) {
                    sum += m.indexInRange(0, limit++ % (m.length())).lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueIndexInRangeShort() {
                int sum = 0;
                int limit = 0;
                VectorMask<Short> m = VectorMask.fromArray(sspecies, mask_arr, 0);
                for (int i = 0; i < size; i += sspecies.length()) {
                    sum += m.indexInRange(0, limit++ % (m.length())).lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueIndexInRangeInt() {
                int sum = 0;
                int limit = 0;
                VectorMask<Integer> m = VectorMask.fromArray(ispecies, mask_arr, 0);
                for (int i = 0; i < size; i += ispecies.length()) {
                    sum += m.indexInRange(0, limit++ % (m.length())).lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueIndexInRangeLong() {
                int sum = 0;
                int limit = 0;
                VectorMask<Long> m = VectorMask.fromArray(lspecies, mask_arr, 0);
                for (int i = 0; i < size; i += lspecies.length()) {
                    sum += m.indexInRange(0, limit++ % (m.length())).lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueIndexInRangeFloat() {
                int sum = 0;
                int limit = 0;
                VectorMask<Float> m = VectorMask.fromArray(fspecies, mask_arr, 0);
                for (int i = 0; i < size; i += fspecies.length()) {
                    sum += m.indexInRange(0, limit++ % (m.length())).lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueIndexInRangeDouble() {
                int sum = 0;
                int limit = 0;
                VectorMask<Double> m = VectorMask.fromArray(dspecies, mask_arr, 0);
                for (int i = 0; i < size; i += dspecies.length()) {
                    sum += m.indexInRange(0, limit++ % (m.length())).lastTrue();
                }
                return sum;
            }


            // VectorMask.fromLong + lastTrue

            @Benchmark
            public int testLastTrueFromLongByte() {
                int sum = 0;
                for (int i = 0; i < size; i += bspecies.length()) {
                    VectorMask<Byte> m = VectorMask.fromLong(bspecies, i);
                    sum += m.lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueFromLongShort() {
                int sum = 0;
                for (int i = 0; i < size; i += sspecies.length()) {
                    VectorMask<Short> m = VectorMask.fromLong(sspecies, i);
                    sum += m.lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueFromLongInt() {
                int sum = 0;
                for (int i = 0; i < size; i += ispecies.length()) {
                    VectorMask<Integer> m = VectorMask.fromLong(ispecies, i);
                    sum += m.lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueFromLongLong() {
                int sum = 0;
                for (int i = 0; i < size; i += lspecies.length()) {
                    VectorMask<Long> m = VectorMask.fromLong(lspecies, i);
                    sum += m.lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueFromLongFloat() {
                int sum = 0;
                for (int i = 0; i < size; i += fspecies.length()) {
                    VectorMask<Float> m = VectorMask.fromLong(fspecies, i);
                    sum += m.lastTrue();
                }
                return sum;
            }

            @Benchmark
            public int testLastTrueFromLongDouble() {
                int sum = 0;
                for (int i = 0; i < size; i += dspecies.length()) {
                    VectorMask<Double> m = VectorMask.fromLong(dspecies, i);
                    sum += m.lastTrue();
                }
                return sum;
            }


            // VectorMask.fromArray + lastTrue & toLong
            // Before:
            // LoadVector + VectorLoadMask + VectorMaskLastTrue
            // + VectorMaskToLong
            // After:
            // LoadVector + VectorMaskLastTrue
            // + VectorLoadMask + VectorMaskToLong
            //
            // The match rule for "LoadVector + VectorLoadMask" doesn't apply, since LoadVector has multiple uses.

            @Benchmark
            public long testMultiUsesFromArrayByte() {
                long sum = 0;
                for (int i = 0; i < size; i += bspecies.length()) {
                    VectorMask<Byte> m = VectorMask.fromArray(bspecies, mask_arr, i);
                    sum += m.lastTrue();
                    sum += m.toLong();
                }
                return sum;
            }

            @Benchmark
            public long testMultiUsesFromArrayShort() {
                long sum = 0;
                for (int i = 0; i < size; i += sspecies.length()) {
                    VectorMask<Short> m = VectorMask.fromArray(sspecies, mask_arr, i);
                    sum += m.lastTrue();
                    sum += m.toLong();
                }
                return sum;
            }

            @Benchmark
            public long testMultiUsesFromArrayInt() {
                long sum = 0;
                for (int i = 0; i < size; i += ispecies.length()) {
                    VectorMask<Integer> m = VectorMask.fromArray(ispecies, mask_arr, i);
                    sum += m.lastTrue();
                    sum += m.toLong();
                }
                return sum;
            }

            @Benchmark
            public long testMultiUsesFromArrayLong() {
                long sum = 0;
                for (int i = 0; i < size; i += lspecies.length()) {
                    VectorMask<Long> m = VectorMask.fromArray(lspecies, mask_arr, i);
                    sum += m.lastTrue();
                    sum += m.toLong();
                }
                return sum;
            }

            @Benchmark
            public long testMultiUsesFromArrayFloat() {
                long sum = 0;
                for (int i = 0; i < size; i += fspecies.length()) {
                    VectorMask<Float> m = VectorMask.fromArray(fspecies, mask_arr, i);
                    sum += m.lastTrue();
                    sum += m.toLong();
                }
                return sum;
            }

            @Benchmark
            public long testMultiUsesFromArrayDouble() {
                long sum = 0;
                for (int i = 0; i < size; i += dspecies.length()) {
                    VectorMask<Double> m = VectorMask.fromArray(dspecies, mask_arr, i);
                    sum += m.lastTrue();
                    sum += m.toLong();
                }
                return sum;
            }


            // Vector.compare + lastTrue & toLong

            @Benchmark
            public long testMultiUsesCompareByte() {
                long sum = 0;
                for (int i = 0; i < size; i += bspecies.length()) {
                    ByteVector v = ByteVector.fromArray(bspecies, byte_arr, i);
                    VectorMask<Byte> m = v.compare(VectorOperators.LT, 0);
                    sum += m.lastTrue();
                    sum += m.toLong();
                }
                return sum;
            }

            @Benchmark
            public long testMultiUsesCompareShort() {
                long sum = 0;
                for (int i = 0; i < size; i += sspecies.length()) {
                    ShortVector v = ShortVector.fromArray(sspecies, short_arr, i);
                    VectorMask<Short> m = v.compare(VectorOperators.LT, 0);
                    sum += m.lastTrue();
                    sum += m.toLong();
                }
                return sum;
            }

            @Benchmark
            public long testMultiUsesCompareInt() {
                long sum = 0;
                for (int i = 0; i < size; i += ispecies.length()) {
                    IntVector v = IntVector.fromArray(ispecies, int_arr, i);
                    VectorMask<Integer> m = v.compare(VectorOperators.LT, 0);
                    sum += m.lastTrue();
                    sum += m.toLong();
                }
                return sum;
            }

            @Benchmark
            public long testMultiUsesCompareLong() {
                long sum = 0;
                for (int i = 0; i < size; i += lspecies.length()) {
                    LongVector v = LongVector.fromArray(lspecies, long_arr, i);
                    VectorMask<Long> m = v.compare(VectorOperators.LT, 0);
                    sum += m.lastTrue();
                    sum += m.toLong();
                }
                return sum;
            }

            @Benchmark
            public long testMultiUsesCompareFloat() {
                long sum = 0;
                for (int i = 0; i < size; i += fspecies.length()) {
                    FloatVector v = FloatVector.fromArray(fspecies, float_arr, i);
                    VectorMask<Float> m = v.compare(VectorOperators.LT, 0);
                    sum += m.lastTrue();
                    sum += m.toLong();
                }
                return sum;
            }

            @Benchmark
            public long testMultiUsesCompareDouble() {
                long sum = 0;
                for (int i = 0; i < size; i += dspecies.length()) {
                    DoubleVector v = DoubleVector.fromArray(dspecies, double_arr, i);
                    VectorMask<Double> m = v.compare(VectorOperators.LT, 0);
                    sum += m.lastTrue();
                    sum += m.toLong();
                }
                return sum;
            }

        }
        ```

        To reproduce the crash, run the following test command:
        ```
        make test TEST=micro:org.openjdk.bench.jdk.incubator.vector.MaskLastTrueBenchmark.*
        ```

        [~epeter] The assert was introduced by https://github.com/openjdk/jdk/commit/2ac24bf1bac9c32704ebd72b93a75819b9404063 , so would you mind taking a look? Thanks!

          1. tmp.xml
            302 kB
          2. TestOptimizeStoreVector.java
            1 kB
          3. TestOptimizeLoadVector.java
            2 kB
          4. Test1.java
            4 kB
          5. replay_pid144663.log
            277 kB
          6. Int256VectorTests.java
            6 kB
          7. igv-dump.txt
            116 kB
          8. hs_err_pid186844.log
            115 kB
          9. hs_err_pid144663.log
            91 kB

              Assignee:
              Xiaohong Gong
              Reporter:
              Eric Fang
              Votes:
              0
              Watchers:
              7