import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.function.Supplier;

/**
 * Reproducer for stale {@link Iterator#remove()} calls on {@link ConcurrentMap} views.
 *
 * <p>Each case takes one element from an {@code entrySet()} or {@code values()} iterator,
 * replaces the underlying mapping with {@code put("a", "B")}, and then calls
 * {@code iterator.remove()}. The expectation encoded here is that removing the stale element
 * leaves the replacement mapping {@code a=B} in place. The original version of this file marked
 * the final size checks as failing for the implementations exercised in {@link #runAll()}.
 */
public class ConcurrentMapTest {

    /** A named, runnable test case. */
    private static final class Case {
        final String name;
        final Runnable body;

        Case(String name, Runnable body) {
            this.name = name;
            this.body = body;
        }
    }

    /**
     * Removes a stale entry through an {@code entrySet()} iterator and checks that the
     * replacement mapping survives.
     *
     * @param mapFactory supplies a fresh, empty map for the case
     * @throws AssertionError if a precondition or the final expectation does not hold
     */
    private static void testEntrySet(Supplier<ConcurrentMap<String, String>> mapFactory) {
        ConcurrentMap<String, String> map = mapFactory.get();
        map.put("a", "A");
        Iterator<Entry<String, String>> iterator = map.entrySet().iterator();
        Entry<String, String> next = iterator.next();
        // Explicit check rather than `assert`, so the precondition is verified even without -ea.
        check("a".equals(next.getKey()) && "A".equals(next.getValue()), map,
                "Unexpected entry from iterator: " + next);
        // Replace mapping
        map.put("a", "B");
        // Should have no effect because Entry("a", "A") does not exist anymore
        iterator.remove();

        checkMappingSurvived(map);
    }

    /**
     * Removes a stale value through a {@code values()} iterator and checks that the
     * replacement mapping survives.
     *
     * @param mapFactory supplies a fresh, empty map for the case
     * @throws AssertionError if a precondition or the final expectation does not hold
     */
    private static void testValues(Supplier<ConcurrentMap<String, String>> mapFactory) {
        ConcurrentMap<String, String> map = mapFactory.get();
        map.put("a", "A");
        Iterator<String> iterator = map.values().iterator();
        String next = iterator.next();
        // Explicit check rather than `assert`, so the precondition is verified even without -ea.
        check("A".equals(next), map, "Unexpected value from iterator: " + next);
        // Replace mapping
        map.put("a", "B");
        // Should have no effect because value "A" does not exist anymore
        iterator.remove();

        checkMappingSurvived(map);
    }

    /**
     * Verifies that {@code map} still holds exactly the replacement mapping {@code a=B}.
     *
     * <p>The size check was marked "Fails" in the original reproducer. A size other than 1
     * here means the stale {@code remove()} deleted key {@code a} anyway.
     *
     * @throws AssertionError if the mapping was removed or altered
     */
    private static void checkMappingSurvived(ConcurrentMap<String, String> map) {
        int size = map.size();
        check(size == 1, map, "Wrong map size: expected 1 but was " + size);
        String value = map.get("a");
        check("B".equals(value), map, "Wrong value for key a: expected B but was " + value);
    }

    /**
     * Throws an {@link AssertionError} prefixed with the map's implementation class when
     * {@code condition} is false. Unlike {@code assert}, this runs regardless of {@code -ea}.
     */
    private static void check(boolean condition, ConcurrentMap<?, ?> map, String message) {
        if (!condition) {
            throw new AssertionError(map.getClass().getSimpleName() + ": " + message);
        }
    }

    /**
     * Runs every case and reports each failure on {@code System.err}.
     *
     * <p>Unexpected runtime exceptions are treated as failures, so one misbehaving case does not
     * prevent the remaining cases from running.
     *
     * @return the names of the cases that failed, in execution order (empty if all passed)
     */
    static List<String> runAll() {
        List<Case> cases = Arrays.asList(
                new Case("testEntrySet(ConcurrentHashMap)",
                        () -> testEntrySet(ConcurrentHashMap::new)),
                new Case("testValues(ConcurrentHashMap)",
                        () -> testValues(ConcurrentHashMap::new)),
                new Case("testEntrySet(ConcurrentSkipListMap)",
                        () -> testEntrySet(ConcurrentSkipListMap::new)),
                new Case("testValues(ConcurrentSkipListMap)",
                        () -> testValues(ConcurrentSkipListMap::new))
        );
        List<String> failed = new ArrayList<>();
        for (Case testCase : cases) {
            try {
                testCase.body.run();
            } catch (AssertionError | RuntimeException error) {
                failed.add(testCase.name);
                System.err.println(testCase.name + " FAILED: " + error);
            }
        }
        return failed;
    }

    /** Runs all cases and prints a one-line summary; failures are reported but not rethrown. */
    public static void main(String[] args) {
        List<String> failed = runAll();
        if (failed.isEmpty()) {
            System.out.println("All cases passed");
        } else {
            System.out.println(failed.size() + " case(s) failed: " + failed);
        }
    }
}