import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.LinkedList;
import java.util.List;

/**
 * Demonstrate performance edge case in ThreadLocalMap.
 * Run with: `java --add-opens=java.base/java.lang=ALL-UNNAMED ThreadLocalTest.java`
 * Observe - sometimes a long run of entries will build up in the hash table, and time to remove
 * one element will degrade massively, e.g.
 * ```
 * Length: 524288, numEntries: 319309, maxRun: 18449, time to remove one element: 672.975ms
 * ```
 * Seems to be sensitive to GC behaviour, only observed an issue so far with G1.
 */
public class ThreadLocalTest {

    /**
     * Strong references to the seeded thread-locals, so their map entries are never
     * cleared by the garbage collector while the churn loop runs.
     */
    static final List<ThreadLocal<Integer>> retained = new LinkedList<>();

    /**
     * Seeds some retained entries, then churns the current thread's thread-local map with
     * unretained and add-then-remove entries, printing statistics every 100 iterations.
     *
     * @param args unused
     * @throws Exception declared for convenience; failures in {@link #printStats()} surface as
     *                   unchecked exceptions
     */
    public static void main(String[] args) throws Exception {
        // Seed some entries for removal speed tests
        addEntries(1000, true);

        // Main churn loop
        for (int i = 0; i < 1_000_000; i++) {
            // add some unretained entries to be garbage collected later
            addEntries(100, false);

            // add and remove some
            for (int j = 0; j < 100; j++) {
                ThreadLocal<Integer> tl = new ThreadLocal<>();
                tl.set(123);
                tl.remove();
            }

            if (i % 100 == 0) {
                printStats();
            }
        }
    }

    /**
     * Creates {@code numEntries} new thread-locals and sets a value on each for the current thread.
     *
     * @param numEntries number of thread-locals to create
     * @param retain     when {@code true}, each thread-local is added to {@link #retained};
     *                   otherwise no reference to it is kept once this method returns
     */
    static void addEntries(int numEntries, boolean retain) {
        for (int i = 0; i < numEntries; i++) {
            ThreadLocal<Integer> tl = new ThreadLocal<>();
            tl.set(123);
            if (retain) {
                retained.add(tl);
            }
        }
    }

    /**
     * Finds the longest run of consecutive non-null slots in {@code table}, treating the table as
     * circular: a run that reaches the last slot continues at index 0.
     *
     * <p>The scan starts just after the first empty slot and ends on that same empty slot, so every
     * run (including one that wraps around or touches the end of the array) is closed and measured
     * exactly once. If several runs share the maximum length, the first one met in that scan order
     * is reported.
     *
     * @param table the slot array to scan; must not be {@code null}
     * @return a two-element array {@code {startIndex, length}}; {@code {-1, 0}} if no slot is
     *         occupied (including a zero-length table); {@code {0, table.length}} if every slot
     *         is occupied
     */
    static int[] findLongestRun(Object[] table) {
        int len = table.length;

        int firstEmpty = -1;
        for (int i = 0; i < len; i++) {
            if (table[i] == null) {
                firstEmpty = i;
                break;
            }
        }
        if (firstEmpty < 0) {
            // Either a zero-length table or no empty slot at all.
            return (len == 0) ? new int[] {-1, 0} : new int[] {0, len};
        }

        int maxRun = 0;
        int maxRunIndex = -1;
        int currentRun = 0;
        for (int step = 1; step <= len; step++) {
            int i = (firstEmpty + step) % len;
            if (table[i] == null) {
                // end of run
                if (currentRun > maxRun) {
                    maxRun = currentRun;
                    maxRunIndex = (i - currentRun + len) % len;
                }
                currentRun = 0;
            } else {
                currentRun++;
            }
        }
        return new int[] {maxRunIndex, maxRun};
    }

    /**
     * Prints the current thread's thread-local map table length, its entry count, the length of
     * the longest run of occupied slots, and the time taken by a reflective call to
     * {@code expungeStaleEntry} on the first slot of that run.
     *
     * <p>Requires {@code --add-opens=java.base/java.lang=ALL-UNNAMED}. The timed call mutates the
     * map; the entry count printed is the value read before that call.
     *
     * @throws IllegalStateException if a reflective lookup or invocation fails
     */
    static void printStats() {
        try {
            Field threadLocalsField = Thread.class.getDeclaredField("threadLocals");
            threadLocalsField.setAccessible(true);
            Class<?> tlmClass = Class.forName("java.lang.ThreadLocal$ThreadLocalMap");
            Field sizeField = tlmClass.getDeclaredField("size");
            Field tableField = tlmClass.getDeclaredField("table");
            sizeField.setAccessible(true);
            tableField.setAccessible(true);
            Method expungeStaleEntry = tlmClass.getDeclaredMethod("expungeStaleEntry", int.class);
            expungeStaleEntry.setAccessible(true);

            Object threadLocalMap = threadLocalsField.get(Thread.currentThread());
            if (threadLocalMap == null) {
                // Nothing has been stored in a ThreadLocal on this thread yet.
                System.out.println("No ThreadLocalMap present on the current thread");
                return;
            }

            int numEntries = (int) sizeField.get(threadLocalMap);
            Object[] table = (Object[]) tableField.get(threadLocalMap);
            int len = table.length;

            int[] run = findLongestRun(table);
            int maxRunIndex = run[0];
            int maxRun = run[1];

            if (maxRun == 0) {
                // expungeStaleEntry expects an occupied slot; there is nothing to time.
                System.out.printf("Length: %d, numEntries: %d, maxRun: 0, nothing to remove%n", len, numEntries);
                return;
            }

            // Remove the first element in the run, which should be the most expensive to remove:
            // in the OpenJDK implementation expungeStaleEntry clears the given slot and then
            // walks the rest of the run up to the next empty slot.
            long start = System.nanoTime();
            expungeStaleEntry.invoke(threadLocalMap, maxRunIndex);
            long dur = System.nanoTime() - start;

            System.out.printf("Length: %d, numEntries: %d, maxRun: %d, time to remove one element: %.3fms%n", len, numEntries, maxRun, dur / 1_000_000.0);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Reflective access to ThreadLocalMap failed", e);
        }
    }
}
