import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong; 

/**
 * Stress test that races {@code entrySet().removeIf} on a {@link ConcurrentHashMap} against a
 * concurrent remove-then-re-put of the entries being traversed.
 *
 * <p>Each round parks the {@code removeIf} predicate on the first entry it visits. A second task
 * then removes and re-inserts both keys, after which the traversal resumes. If either task does not
 * finish within {@link #TIMEOUT_SECONDS} seconds, the round is reported as a likely traversal
 * cycle. All thread stacks are dumped to stderr and an {@link AssertionError} is thrown.
 *
 * <p>Usage: {@code java Test [rounds]} (default {@value #DEFAULT_ROUNDS} rounds).
 */
public class Test {
    /** Number of rounds run when no command-line argument is given. */
    private static final long DEFAULT_ROUNDS = 5_000_000L; // TODO what number makes sense here?

    /** How long to wait for each task of a round before treating the round as stuck. */
    private static final long TIMEOUT_SECONDS = 3;

    public static void main(String[] args) throws Exception {
        final long rounds = args.length > 0 ? Long.parseLong(args[0]) : DEFAULT_ROUNDS;

        final Map<Long, Long> map = new ConcurrentHashMap<>();
        // Long.hashCode() is (int) (value ^ (value >>> 32)), which is 0 for both keys,
        // so they collide in the same bin.
        final Long key1 = 0L;
        final Long key2 = (1L << 32) + 1;

        map.put(key1, 0L);
        map.put(key2, 0L);

        // Daemon threads + shutdownNow() rather than try-with-resources: ExecutorService.close()
        // waits indefinitely for running tasks. A removeIf stuck in a loop would therefore keep
        // the JVM alive forever, and the failure raised by runRound would never leave main.
        final ExecutorService svc = Executors.newFixedThreadPool(2, Test::newDaemonThread);
        try {
            for (long currentRound = 0; currentRound < rounds; currentRound++) {
                runRound(svc, map, key1, key2, currentRound);
            }
        } finally {
            svc.shutdownNow();
        }
    }

    /**
     * Runs one round: a removeIf traversal that parks on its first visited entry while a second
     * task removes and re-puts both keys.
     *
     * @param svc   executor with at least two threads
     * @param map   map under test, expected to contain {@code key1} and {@code key2}
     * @param key1  first key to remove and re-put
     * @param key2  second key to remove and re-put
     * @param round zero-based round index, used only for diagnostics
     * @throws AssertionError       if a task does not finish within {@link #TIMEOUT_SECONDS}
     * @throws ExecutionException   if either task throws
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    private static void runRound(ExecutorService svc, Map<Long, Long> map, Long key1, Long key2,
            long round) throws InterruptedException, ExecutionException {
        // ready: released by the predicate once the traversal is in progress.
        // done: released by the modifier once it has finished (or was interrupted).
        final CountDownLatch ready = new CountDownLatch(1);
        final CountDownLatch done = new CountDownLatch(1);

        final Future<?> modifier = svc.submit(() -> {
            try {
                ready.await();
                map.put(key1, map.remove(key1));
                map.put(key2, map.remove(key2));
            } catch (InterruptedException ie) {
                // Only happens on shutdown; preserve the interrupt status for the pool.
                Thread.currentThread().interrupt();
            } finally {
                // Always release the remover, even if the modification did not happen.
                done.countDown();
            }
        });

        final Future<?> remover = svc.submit(() -> {
            System.out.println("Starting removeIf, map: " + map);
            map.entrySet().removeIf(entry -> {
                if (ready.getCount() > 0) { // park only on the first entry visited this round
                    ready.countDown();
                    try {
                        done.await();
                    } catch (InterruptedException ie) {
                        Thread.currentThread().interrupt();
                        throw new AssertionError("interrupted while waiting for modifier", ie);
                    }
                }
                return false; // never remove anything; only the traversal itself is under test
            });
        });

        try {
            modifier.get(TIMEOUT_SECONDS, TimeUnit.SECONDS);
            remover.get(TIMEOUT_SECONDS, TimeUnit.SECONDS);
        } catch (TimeoutException te) {
            System.err.println("Likely detection of cycle at round: " + round);
            dumpAllThreads();
            throw new AssertionError(
                    "task did not finish within " + TIMEOUT_SECONDS + "s at round " + round, te);
        }
    }

    /**
     * Thread factory for the test pool. It creates daemon threads so that a task stuck in an
     * uninterruptible loop cannot prevent JVM exit.
     */
    private static Thread newDaemonThread(Runnable task) {
        final Thread thread = Executors.defaultThreadFactory().newThread(task);
        thread.setDaemon(true);
        return thread;
    }

    /** Prints the name, state and stack trace of every live thread to stderr. */
    private static void dumpAllThreads() {
        final Map<Thread, StackTraceElement[]> traces = Thread.getAllStackTraces();
        for (Map.Entry<Thread, StackTraceElement[]> e : traces.entrySet()) {
            final Thread t = e.getKey();
            System.err.println("\nThread: " + t.getName() + " (state: " + t.getState() + ")");
            for (StackTraceElement ste : e.getValue()) {
                System.err.println("    at " + ste);
            }
        }
    }
}
