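/*

Stress test for a time-of-check/time-of-use race in ClassLoader.defineClass.
A background thread zeroes one byte of a generated class file held in native
memory while the main thread defines a class from that memory. The main loop
prints every string constant that comes back altered, and it runs until killed.

Requires JDK 22 or newer (java.lang.foreign without preview flags) and ASM 9.8:

wget https://repo1.maven.org/maven2/org/ow2/asm/asm/9.8/asm-9.8.jar
mkdir -p toctou
javac -d toctou -cp asm-9.8.jar TocTouTest.java
java -cp toctou:asm-9.8.jar toctou.TocTouTest

*/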


package toctou;

import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.reflect.Method;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.atomic.AtomicInteger;

import static org.objectweb.asm.Opcodes.*;

public class TocTouTest {

    /** JVMS constant pool tag of a {@code CONSTANT_Utf8_info} entry. */
    private static final int CONSTANT_UTF8_TAG = 1;

    /**
     * Repeatedly defines a small generated class from native memory while a daemon thread
     * overwrites the first character byte of one of its {@code CONSTANT_Utf8} entries with
     * {@code 0}, and prints every string constant that the loaded class returns altered.
     *
     * <p>Attempts that fail with {@link ClassFormatError} are retried. Each successful
     * definition uses a fresh class loader. The loop never terminates on its own.
     *
     * @param args ignored
     * @throws Exception if reflection on, or invocation of, the generated method fails
     */
    public static void main(String[] args) throws Exception {
        // Computed at run time instead of written as a literal. NOTE(review): presumably so that
        // this class's own constant pool does not already contain the string; confirm intent.
        var string = "corrupted_strin!".replace('!', 'g');
        // The value the string/method name takes once its first byte has been zeroed.
        var corruptedString = '\0' + string.substring(1);

        byte[] bytes = buildTestClass(string);
        int byteOffset = findUtf8ContentOffset(bytes, string);

        /** One native copy of the class file, shared by the defining thread and the mutator. */
        record Entry(MemorySegment segment, Arena arena, AtomicInteger uses) {

            /** Drops one reference; whichever holder drops the last one closes the arena. */
            void free() {
                if (uses.decrementAndGet() == 0) {
                    arena.close();
                }
            }
        }
        var queue = new ConcurrentLinkedDeque<Entry>();

        // Mutator: busy-polls the queue and zeroes the target byte of each published copy.
        // Nothing in this file interrupts it; as a daemon it ends with the JVM.
        Thread.ofPlatform().daemon().start(() -> {
            do {
                Entry entry;
                if ((entry = queue.poll()) != null) {
                    try {
                        entry.segment().set(ValueLayout.JAVA_BYTE, byteOffset, (byte) 0);
                    } finally {
                        entry.free();
                    }
                }
            } while (!Thread.interrupted());
        });

        class Loader extends ClassLoader {

            /** Copies the class file into native memory, publishes it, and defines it from there. */
            Class<?> define() {
                var arena = Arena.ofShared();
                var segment = arena.allocate(bytes.length);
                MemorySegment.copy(MemorySegment.ofArray(bytes), 0, segment, 0, bytes.length);
                // Two holders (this method and the mutator): the arena is closed only after both
                // are done with the segment, so neither touches freed memory.
                var entry = new Entry(segment, arena, new AtomicInteger(2));
                queue.offer(entry);
                try {
                    return defineClass(null, segment.asByteBuffer(), null);
                } finally {
                    entry.free();
                }
            }
        }

        var loader = new Loader();
        for (;;) {
            Class<?> c;
            try {
                c = loader.define();
                // A loader cannot define the same class twice, so switch after each success.
                loader = new Loader();
            } catch (ClassFormatError ignored) {
                // Most likely the zero byte landed before the class file was checked: retry.
                continue;
            }
            Method method;
            try {
                method = c.getDeclaredMethod(string);
            } catch (NoSuchMethodException ignored) {
                // The method name is expected to share its Utf8 entry with the ldc operand
                // (ASM deduplicates constants), so it may have been corrupted as well.
                method = c.getDeclaredMethod(corruptedString);
            }
            var myString = (String) method.invoke(null);
            if (!string.equals(myString)) {
                System.out.println(myString);
            }
        }
    }

    /**
     * Generates the class {@code toctou/Test}, whose single public static method, named
     * {@code name}, returns the string constant {@code name} via {@code ldc}.
     *
     * @param name method name and returned string constant
     * @return the class file bytes
     */
    private static byte[] buildTestClass(String name) {
        // Flags 0: nothing is computed by ASM, so max stack/locals are given explicitly below.
        var cw = new ClassWriter(0);
        cw.visit(V17, ACC_PUBLIC, "toctou/Test", null, "java/lang/Object", null);
        var mv = cw.visitMethod(ACC_PUBLIC | ACC_STATIC, name, "()Ljava/lang/String;", null, null);
        mv.visitCode();
        mv.visitLdcInsn(name);
        mv.visitInsn(ARETURN);
        mv.visitMaxs(1, 0);
        mv.visitEnd();
        cw.visitEnd();
        return cw.toByteArray();
    }

    /**
     * Locates the {@code CONSTANT_Utf8} entry equal to {@code value}.
     *
     * @param bytes class file bytes
     * @param value the string to look for
     * @return the offset in {@code bytes} of the entry's first character byte
     * @throws IllegalStateException if no such entry exists
     * @throws IOException if an entry cannot be decoded as modified UTF-8
     */
    private static int findUtf8ContentOffset(byte[] bytes, String value) throws IOException {
        var reader = new ClassReader(bytes);
        for (int i = 1; i < reader.getItemCount(); i++) {
            // getItem yields the offset just past the entry's tag byte.
            int offset = reader.getItem(i);
            // ASM reports 0 for the unusable slot after a long/double entry; skip it rather
            // than reading bytes[-1].
            if (offset <= 0 || (bytes[offset - 1] & 0xFF) != CONSTANT_UTF8_TAG) {
                continue;
            }
            var in = new DataInputStream(
                    new ByteArrayInputStream(bytes, offset, bytes.length - offset));
            if (in.readUTF().equals(value)) {
                // Skip the two-byte length prefix to reach the first character byte.
                return offset + 2;
            }
        }
        throw new IllegalStateException("No CONSTANT_Utf8 entry equal to \"" + value + "\"");
    }
}
