C2: Missing Ideal optimizations for load and store vectors on SVE


    • aarch64
    • linux

      The issue appears to be caused by mistakes made in JDK-8286941 (JDK20), combined with the verification added in JDK-8367389 (JDK26).

      As a consequence of JDK-8286941, we do not perform all Ideal optimizations for load and store vectors. This is an issue on its own, because it means we do not optimize as much as we could.

      However, with JDK-8367389 we now hit a particular case with multiple LoadVector nodes whose memory inputs are separate MergeMem nodes. We could have stepped over these MergeMem nodes so that all the loads reach the same memory state, but the mistakes from JDK-8286941 prevent this. As a result, we have multiple loads that SHOULD share the same memory state but instead have different MergeMem inputs. This triggers the assert in SuperWord that was introduced with JDK-8367389.
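      The following toy Java model sketches the invariant behind the assert; it is not HotSpot code. It assumes a simplified MergeMem that maps alias indices to memory states on top of a base memory, and all names (`MemorySliceModel`, `MemState`, `MergeMem`, `Load`, `stepThroughMergeMem`, `slicesConsistent`, `demo`) are hypothetical. Two loads of the same slice reach the same memory Phi through two different MergeMem nodes; only when an Ideal-style step bypasses the MergeMem nodes do both loads end up with the same memory input, which is what the SuperWord check expects.

      ```java
      import java.util.HashMap;
      import java.util.List;
      import java.util.Map;

      // Toy model only: these classes are hypothetical stand-ins, not HotSpot C2 code.
      final class MemorySliceModel {
          // A memory state (e.g. a loop memory Phi) in the toy graph.
          static class MemState {
              final String name;
              MemState(String name) { this.name = name; }
          }

          // Simplified MergeMem: per-alias-index memory states on top of a base memory.
          static final class MergeMem extends MemState {
              final MemState base;
              final Map<Integer, MemState> slices = new HashMap<>();
              MergeMem(String name, MemState base) { super(name); this.base = base; }
              // Returns the memory state for one alias slice, falling back to the base memory.
              MemState memoryAt(int aliasIdx) { return slices.getOrDefault(aliasIdx, base); }
          }

          // A load in one alias slice; 'memory' plays the role of load->in(1) in the assert.
          static final class Load {
              final int aliasIdx;
              MemState memory;
              Load(int aliasIdx, MemState memory) { this.aliasIdx = aliasIdx; this.memory = memory; }
          }

          // Ideal-style step (simplified): bypass a MergeMem input by taking this slice's state.
          static void stepThroughMergeMem(Load load) {
              if (load.memory instanceof MergeMem mm) {
                  load.memory = mm.memoryAt(load.aliasIdx);
              }
          }

          // Simplified invariant: every load in a slice must have the same memory input.
          static boolean slicesConsistent(List<Load> loads) {
              Map<Integer, MemState> inputs = new HashMap<>();
              for (Load load : loads) {
                  MemState seen = inputs.putIfAbsent(load.aliasIdx, load.memory);
                  if (seen != null && seen != load.memory) {
                      return false; // corresponds to the failing assert in the real code
                  }
              }
              return true;
          }

          // Two loads of the same slice reach one memory Phi through two different MergeMems.
          static boolean demo(boolean runIdeal) {
              MemState phi = new MemState("loop memory Phi");
              MergeMem mm1 = new MergeMem("MergeMem#1", phi);
              MergeMem mm2 = new MergeMem("MergeMem#2", phi);
              int slice = 5; // arbitrary alias index, for illustration only
              List<Load> loads = List.of(new Load(slice, mm1), new Load(slice, mm2));
              if (runIdeal) {
                  loads.forEach(MemorySliceModel::stepThroughMergeMem);
              }
              return slicesConsistent(loads);
          }

          public static void main(String[] args) {
              System.out.println("without stepping: consistent = " + demo(false));
              System.out.println("with stepping:    consistent = " + demo(true));
          }
      }
      ```

      Running `main` prints `consistent = false` when the MergeMem nodes are left in place and `consistent = true` once both loads have been stepped through to the shared Phi, mirroring why skipping the Ideal step leads to the assert.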

      ---------------------------------------------- ORIGINAL REPORT ------------------------------------------------------------

      When running a JMH benchmark on an **AWS Graviton3 machine (with 256-bit SVE support)**, I hit the following assert failure. The log is as follows:

      ```
      # Run progress: 0.00% complete, ETA 00:00:14
      # Fork: 1 of 1
      WARNING: Using incubator modules: jdk.incubator.vector
      WARNING: A terminally deprecated method in sun.misc.Unsafe has been called
      WARNING: sun.misc.Unsafe::objectFieldOffset has been called by org.openjdk.jmh.util.Utils (file:/localhome/jadmin/erfang/jdk/build/linux-aarch64-server-fastdebug/images/test/micro/benchmarks.jar)
      WARNING: Please consider reporting this to the maintainers of class org.openjdk.jmh.util.Utils
      WARNING: sun.misc.Unsafe::objectFieldOffset will be removed in a future release
      # Warmup Iteration 1: #
      # A fatal error has been detected by the Java Runtime Environment:
      #
      # Internal Error (/localhome/jadmin/erfang/jdk/src/hotspot/share/opto/vectorization.cpp:231), pid=167298, tid=167336
      # assert(_inputs.at(alias_idx) == nullptr || _inputs.at(alias_idx) == load->in(1)) failed: not yet touched or the same input
      #
      # JRE version: OpenJDK Runtime Environment (26.0) (fastdebug build 26-internal-adhoc.jadmin.jdk)
      # Java VM: OpenJDK 64-Bit Server VM (fastdebug 26-internal-adhoc.jadmin.jdk, mixed mode, sharing, tiered, compressed oops, compressed class ptrs, g1 gc, linux-aarch64)
      # Problematic frame:
      # V [libjvm.so+0x1b2d5ec] VLoopMemorySlices::find_memory_slices()+0x29c
      #
      # Core dump will be written. Default location: Core dumps may be processed with "/usr/share/apport/apport -p%p -s%s -c%c -d%d -P%P -u%u -g%g -F%F -- %E" (or dumping to /localhome/jadmin/erfang/jdk/build/linux-aarch64-server-fastdebug/images/test/core.167298)
      #
      # An error report file with more information is saved as:
      # /localhome/jadmin/erfang/jdk/build/linux-aarch64-server-fastdebug/images/test/hs_err_pid167298.log
      ^C
      ERROR: Build failed for target 'test' in configuration 'linux-aarch64-server-fastdebug' (exit code 141)

      No indication of failed target found.
      HELP: Try searching the build log for '] Error'.
      HELP: Run 'make doctor' to diagnose build problems.

      make[1]: *** [/localhome/jadmin/erfang/jdk/make/Init.gmk:151: main] Error 141
      make: *** [/localhome/jadmin/erfang/jdk/make/PreInit.gmk:159: test] Interrupt
      ```

      The test case is a new JMH benchmark file; you can put it in `test/micro/org/openjdk/bench/jdk/incubator/vector/MaskLastTrueBenchmark.java`.

      The code is as follows:
      ```java
      /*
       * Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
       * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
       *
       * This code is free software; you can redistribute it and/or modify it
       * under the terms of the GNU General Public License version 2 only, as
       * published by the Free Software Foundation.
       *
       * This code is distributed in the hope that it will be useful, but WITHOUT
       * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
       * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
       * version 2 for more details (a copy is included in the LICENSE file that
       * accompanied this code).
       *
       * You should have received a copy of the GNU General Public License version
       * 2 along with this work; if not, write to the Free Software Foundation,
       * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
       *
       * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
       * or visit www.oracle.com if you need additional information or have any
       * questions.
       */

      package org.openjdk.bench.jdk.incubator.vector;

      import java.util.Random;
      import jdk.incubator.vector.*;
      import java.util.concurrent.TimeUnit;
      import org.openjdk.jmh.annotations.*;

      @OutputTimeUnit(TimeUnit.MILLISECONDS)
      @State(Scope.Thread)
      @Warmup(iterations = 4, time = 2)
      @Measurement(iterations = 6, time = 1)
      @Fork(value = 1, jvmArgs = {"--add-modules=jdk.incubator.vector"})
      public class MaskLastTrueBenchmark {
          @Param({"128"})
          int size;

          private static final VectorSpecies<Byte> bspecies = VectorSpecies.ofLargestShape(byte.class);
          private static final VectorSpecies<Short> sspecies = VectorSpecies.ofLargestShape(short.class);
          private static final VectorSpecies<Integer> ispecies = VectorSpecies.ofLargestShape(int.class);
          private static final VectorSpecies<Long> lspecies = VectorSpecies.ofLargestShape(long.class);
          private static final VectorSpecies<Float> fspecies = VectorSpecies.ofLargestShape(float.class);
          private static final VectorSpecies<Double> dspecies = VectorSpecies.ofLargestShape(double.class);

          byte[] byte_arr;
          short[] short_arr;
          int[] int_arr;
          long[] long_arr;
          float[] float_arr;
          double[] double_arr;
          boolean[] mask_arr;

          @Setup(Level.Trial)
          public void BmSetup() {
              Random r = new Random();
              byte_arr = new byte[size];
              short_arr = new short[size];
              int_arr = new int[size];
              long_arr = new long[size];
              float_arr = new float[size];
              double_arr = new double[size];
              mask_arr = new boolean[size];

              for (int i = 0; i < size; i++) {
                  byte_arr[i] = (byte) r.nextInt();
                  short_arr[i] = (short) r.nextInt();
                  int_arr[i] = r.nextInt();
                  long_arr[i] = r.nextLong();
                  float_arr[i] = r.nextFloat();
                  double_arr[i] = r.nextDouble();
                  mask_arr[i] = r.nextBoolean();
              }
          }

          // VectorMask.fromArray + lastTrue

          @Benchmark
          public int testLastTrueFromArrayByte() {
              int sum = 0;
              for (int i = 0; i < size; i += bspecies.length()) {
                  VectorMask<Byte> m = VectorMask.fromArray(bspecies, mask_arr, i);
                  sum += m.lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueFromArrayShort() {
              int sum = 0;
              for (int i = 0; i < size; i += sspecies.length()) {
                  VectorMask<Short> m = VectorMask.fromArray(sspecies, mask_arr, i);
                  sum += m.lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueFromArrayInt() {
              int sum = 0;
              for (int i = 0; i < size; i += ispecies.length()) {
                  VectorMask<Integer> m = VectorMask.fromArray(ispecies, mask_arr, i);
                  sum += m.lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueFromArrayLong() {
              int sum = 0;
              for (int i = 0; i < size; i += lspecies.length()) {
                  VectorMask<Long> m = VectorMask.fromArray(lspecies, mask_arr, i);
                  sum += m.lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueFromArrayFloat() {
              int sum = 0;
              for (int i = 0; i < size; i += fspecies.length()) {
                  VectorMask<Float> m = VectorMask.fromArray(fspecies, mask_arr, i);
                  sum += m.lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueFromArrayDouble() {
              int sum = 0;
              for (int i = 0; i < size; i += dspecies.length()) {
                  VectorMask<Double> m = VectorMask.fromArray(dspecies, mask_arr, i);
                  sum += m.lastTrue();
              }
              return sum;
          }


          // Vector.compare + lastTrue

          @Benchmark
          public int testLastTrueCompareByte() {
              int sum = 0;
              for (int i = 0; i < size; i += bspecies.length()) {
                  ByteVector v = ByteVector.fromArray(bspecies, byte_arr, i);
                  sum += v.compare(VectorOperators.LT, 0).lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueCompareShort() {
              int sum = 0;
              for (int i = 0; i < size; i += sspecies.length()) {
                  ShortVector v = ShortVector.fromArray(sspecies, short_arr, i);
                  sum += v.compare(VectorOperators.LT, 0).lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueCompareInt() {
              int sum = 0;
              for (int i = 0; i < size; i += ispecies.length()) {
                  IntVector v = IntVector.fromArray(ispecies, int_arr, i);
                  sum += v.compare(VectorOperators.LT, 0).lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueCompareLong() {
              int sum = 0;
              for (int i = 0; i < size; i += lspecies.length()) {
                  LongVector v = LongVector.fromArray(lspecies, long_arr, i);
                  sum += v.compare(VectorOperators.LT, 0).lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueCompareFloat() {
              int sum = 0;
              for (int i = 0; i < size; i += fspecies.length()) {
                  FloatVector v = FloatVector.fromArray(fspecies, float_arr, i);
                  sum += v.compare(VectorOperators.LT, 0).lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueCompareDouble() {
              int sum = 0;
              for (int i = 0; i < size; i += dspecies.length()) {
                  DoubleVector v = DoubleVector.fromArray(dspecies, double_arr, i);
                  sum += v.compare(VectorOperators.LT, 0).lastTrue();
              }
              return sum;
          }


          // VectorMask.indexInRange + lastTrue

          @Benchmark
          public int testLastTrueIndexInRangeByte() {
              int sum = 0;
              int limit = 0;
              VectorMask<Byte> m = VectorMask.fromArray(bspecies, mask_arr, 0);
              for (int i = 0; i < size; i += bspecies.length()) {
                  sum += m.indexInRange(0, limit++ % (m.length())).lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueIndexInRangeShort() {
              int sum = 0;
              int limit = 0;
              VectorMask<Short> m = VectorMask.fromArray(sspecies, mask_arr, 0);
              for (int i = 0; i < size; i += sspecies.length()) {
                  sum += m.indexInRange(0, limit++ % (m.length())).lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueIndexInRangeInt() {
              int sum = 0;
              int limit = 0;
              VectorMask<Integer> m = VectorMask.fromArray(ispecies, mask_arr, 0);
              for (int i = 0; i < size; i += ispecies.length()) {
                  sum += m.indexInRange(0, limit++ % (m.length())).lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueIndexInRangeLong() {
              int sum = 0;
              int limit = 0;
              VectorMask<Long> m = VectorMask.fromArray(lspecies, mask_arr, 0);
              for (int i = 0; i < size; i += lspecies.length()) {
                  sum += m.indexInRange(0, limit++ % (m.length())).lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueIndexInRangeFloat() {
              int sum = 0;
              int limit = 0;
              VectorMask<Float> m = VectorMask.fromArray(fspecies, mask_arr, 0);
              for (int i = 0; i < size; i += fspecies.length()) {
                  sum += m.indexInRange(0, limit++ % (m.length())).lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueIndexInRangeDouble() {
              int sum = 0;
              int limit = 0;
              VectorMask<Double> m = VectorMask.fromArray(dspecies, mask_arr, 0);
              for (int i = 0; i < size; i += dspecies.length()) {
                  sum += m.indexInRange(0, limit++ % (m.length())).lastTrue();
              }
              return sum;
          }


          // VectorMask.fromLong + lastTrue

          @Benchmark
          public int testLastTrueFromLongByte() {
              int sum = 0;
              for (int i = 0; i < size; i += bspecies.length()) {
                  VectorMask<Byte> m = VectorMask.fromLong(bspecies, i);
                  sum += m.lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueFromLongShort() {
              int sum = 0;
              for (int i = 0; i < size; i += sspecies.length()) {
                  VectorMask<Short> m = VectorMask.fromLong(sspecies, i);
                  sum += m.lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueFromLongInt() {
              int sum = 0;
              for (int i = 0; i < size; i += ispecies.length()) {
                  VectorMask<Integer> m = VectorMask.fromLong(ispecies, i);
                  sum += m.lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueFromLongLong() {
              int sum = 0;
              for (int i = 0; i < size; i += lspecies.length()) {
                  VectorMask<Long> m = VectorMask.fromLong(lspecies, i);
                  sum += m.lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueFromLongFloat() {
              int sum = 0;
              for (int i = 0; i < size; i += fspecies.length()) {
                  VectorMask<Float> m = VectorMask.fromLong(fspecies, i);
                  sum += m.lastTrue();
              }
              return sum;
          }

          @Benchmark
          public int testLastTrueFromLongDouble() {
              int sum = 0;
              for (int i = 0; i < size; i += dspecies.length()) {
                  VectorMask<Double> m = VectorMask.fromLong(dspecies, i);
                  sum += m.lastTrue();
              }
              return sum;
          }


          // VectorMask.fromArray + lastTrue & toLong
          // Before:
          // LoadVector + VectorLoadMask + VectorMaskLastTrue
          // + VectorMaskToLong
          // After:
          // LoadVector + VectorMaskLastTrue
          // + VectorLoadMask + VectorMaskToLong
          //
          // The match rule for "LoadVector + VectorLoadMask" does not match because the LoadVector has multiple uses.

          @Benchmark
          public long testMultiUsesFromArrayByte() {
              long sum = 0;
              for (int i = 0; i < size; i += bspecies.length()) {
                  VectorMask<Byte> m = VectorMask.fromArray(bspecies, mask_arr, i);
                  sum += m.lastTrue();
                  sum += m.toLong();
              }
              return sum;
          }

          @Benchmark
          public long testMultiUsesFromArrayShort() {
              long sum = 0;
              for (int i = 0; i < size; i += sspecies.length()) {
                  VectorMask<Short> m = VectorMask.fromArray(sspecies, mask_arr, i);
                  sum += m.lastTrue();
                  sum += m.toLong();
              }
              return sum;
          }

          @Benchmark
          public long testMultiUsesFromArrayInt() {
              long sum = 0;
              for (int i = 0; i < size; i += ispecies.length()) {
                  VectorMask<Integer> m = VectorMask.fromArray(ispecies, mask_arr, i);
                  sum += m.lastTrue();
                  sum += m.toLong();
              }
              return sum;
          }

          @Benchmark
          public long testMultiUsesFromArrayLong() {
              long sum = 0;
              for (int i = 0; i < size; i += lspecies.length()) {
                  VectorMask<Long> m = VectorMask.fromArray(lspecies, mask_arr, i);
                  sum += m.lastTrue();
                  sum += m.toLong();
              }
              return sum;
          }

          @Benchmark
          public long testMultiUsesFromArrayFloat() {
              long sum = 0;
              for (int i = 0; i < size; i += fspecies.length()) {
                  VectorMask<Float> m = VectorMask.fromArray(fspecies, mask_arr, i);
                  sum += m.lastTrue();
                  sum += m.toLong();
              }
              return sum;
          }

          @Benchmark
          public long testMultiUsesFromArrayDouble() {
              long sum = 0;
              for (int i = 0; i < size; i += dspecies.length()) {
                  VectorMask<Double> m = VectorMask.fromArray(dspecies, mask_arr, i);
                  sum += m.lastTrue();
                  sum += m.toLong();
              }
              return sum;
          }


          // Vector.compare + lastTrue & toLong

          @Benchmark
          public long testMultiUsesCompareByte() {
              long sum = 0;
              for (int i = 0; i < size; i += bspecies.length()) {
                  ByteVector v = ByteVector.fromArray(bspecies, byte_arr, i);
                  VectorMask<Byte> m = v.compare(VectorOperators.LT, 0);
                  sum += m.lastTrue();
                  sum += m.toLong();
              }
              return sum;
          }

          @Benchmark
          public long testMultiUsesCompareShort() {
              long sum = 0;
              for (int i = 0; i < size; i += sspecies.length()) {
                  ShortVector v = ShortVector.fromArray(sspecies, short_arr, i);
                  VectorMask<Short> m = v.compare(VectorOperators.LT, 0);
                  sum += m.lastTrue();
                  sum += m.toLong();
              }
              return sum;
          }

          @Benchmark
          public long testMultiUsesCompareInt() {
              long sum = 0;
              for (int i = 0; i < size; i += ispecies.length()) {
                  IntVector v = IntVector.fromArray(ispecies, int_arr, i);
                  VectorMask<Integer> m = v.compare(VectorOperators.LT, 0);
                  sum += m.lastTrue();
                  sum += m.toLong();
              }
              return sum;
          }

          @Benchmark
          public long testMultiUsesCompareLong() {
              long sum = 0;
              for (int i = 0; i < size; i += lspecies.length()) {
                  LongVector v = LongVector.fromArray(lspecies, long_arr, i);
                  VectorMask<Long> m = v.compare(VectorOperators.LT, 0);
                  sum += m.lastTrue();
                  sum += m.toLong();
              }
              return sum;
          }

          @Benchmark
          public long testMultiUsesCompareFloat() {
              long sum = 0;
              for (int i = 0; i < size; i += fspecies.length()) {
                  FloatVector v = FloatVector.fromArray(fspecies, float_arr, i);
                  VectorMask<Float> m = v.compare(VectorOperators.LT, 0);
                  sum += m.lastTrue();
                  sum += m.toLong();
              }
              return sum;
          }

          @Benchmark
          public long testMultiUsesCompareDouble() {
              long sum = 0;
              for (int i = 0; i < size; i += dspecies.length()) {
                  DoubleVector v = DoubleVector.fromArray(dspecies, double_arr, i);
                  VectorMask<Double> m = v.compare(VectorOperators.LT, 0);
                  sum += m.lastTrue();
                  sum += m.toLong();
              }
              return sum;
          }

      }
      ```

      To reproduce the crash, run the following test command:
      ```
      make test TEST=micro:org.openjdk.bench.jdk.incubator.vector.MaskLastTrueBenchmark.*
      ```

      [~epeter] Would you mind taking a look? The assert was introduced by https://github.com/openjdk/jdk/commit/2ac24bf1bac9c32704ebd72b93a75819b9404063. Thanks!

      Attachments:
        1. hs_err_pid144663.log (91 kB)
        2. hs_err_pid186844.log (115 kB)
        3. igv-dump.txt (116 kB)
        4. Int256VectorTests.java (6 kB)
        5. replay_pid144663.log (277 kB)
        6. Test1.java (4 kB)
        7. TestOptimizeLoadVector.java (2 kB)
        8. TestOptimizeStoreVector.java (1 kB)
        9. tmp.xml (302 kB)

            Assignee: Xiaohong Gong
            Reporter: Eric Fang
            Votes: 0
            Watchers: 6
