import org.junit.Assert; 
import org.junit.Test; 

import java.util.*; 
public class TestHashSet {

    @Test
    public void main() {
        MagicForest magicForest = setup(5, new Edge(0, 1), new Edge(2, 3), new Edge(1, 4), new Edge(3, 4));

        int countTrees = magicForest.countTrees();

        Assert.assertEquals(1, countTrees);
    }

    /** An edge whose endpoints are already in one tree (here a duplicate edge) must not drop that tree. */
    @Test
    public void duplicateEdgeKeepsTree() {
        MagicForest magicForest = setup(2, new Edge(0, 1), new Edge(1, 0));

        Assert.assertEquals(1, magicForest.countTrees());
    }

    /** With no edges, every node is a tree of its own. */
    @Test
    public void isolatedNodesAreSeparateTrees() {
        Assert.assertEquals(3, setup(3).countTrees());
    }

    /**
     * Builds a forest over the nodes {@code 0..nodes-1} connected by the given edges.
     *
     * @param nodes number of nodes
     * @param edges edges between nodes
     * @return the forest to test
     */
    private MagicForest setup(int nodes, Edge... edges) {
        return new MagicForest(nodes, new ArrayList<>(Arrays.asList(edges)));
    }


    /**
     * An undirected graph over the nodes {@code 0..nodes-1}, plus any other node an edge names.
     */
    private static class MagicForest {
        private final int nodes;
        private final List<Edge> edges;

        public MagicForest(int nodes, List<Edge> edges) {
            this.nodes = nodes;
            this.edges = edges;
        }

        /**
         * Counts the connected components ("trees") of the graph. A node that no edge touches
         * counts as a tree on its own.
         *
         * @return the number of trees
         */
        public int countTrees() {
            // Union-find instead of a Set<Set<Integer>> of trees: HashSet files each element under
            // the hash it had when inserted, so growing a tree already stored in the set corrupts
            // the set (remove/contains stop finding the tree).
            FindForest forest = new FindForest();
            for (int node = 0; node < nodes; node++) {
                forest.addNode(node);
            }
            for (Edge edge : edges) {
                forest.union(edge.getEdgePU(), edge.getEdgeDO());
            }
            return forest.treeCount();
        }
    }

    /**
     * Disjoint-set (union-find) structure over integer node ids. Each tree is identified by its
     * root node, which {@link #find(int)} returns.
     */
    private static class FindForest {
        /** Parent of each known node; a root is its own parent. */
        private final Map<Integer, Integer> parent = new HashMap<>();
        /** Number of distinct trees among the known nodes. */
        private int trees;

        /** Registers {@code node} as a single-node tree unless it is already known. */
        public void addNode(int node) {
            if (!parent.containsKey(node)) {
                parent.put(node, node);
                trees++;
            }
        }

        /**
         * Returns the root of the tree containing {@code node}, registering the node first if it
         * is not yet known.
         */
        public int find(int node) {
            addNode(node);
            int root = node;
            while (parent.get(root) != root) {
                root = parent.get(root);
            }
            // Path compression: point every node on the walked path directly at the root.
            int current = node;
            while (current != root) {
                int next = parent.get(current);
                parent.put(current, root);
                current = next;
            }
            return root;
        }

        /** Merges the trees containing {@code a} and {@code b}; does nothing if already joined. */
        public void union(int a, int b) {
            int rootA = find(a);
            int rootB = find(b);
            if (rootA != rootB) {
                parent.put(rootB, rootA);
                trees--;
            }
        }

        /** @return the number of distinct trees among all nodes registered so far */
        public int treeCount() {
            return trees;
        }
    }


    /** An undirected edge between the nodes {@code edgePU} and {@code edgeDO}. */
    private static class Edge {

        private final int edgePU;
        private final int edgeDO;

        public Edge(int edgePU, int edgeDO) {
            this.edgePU = edgePU;
            this.edgeDO = edgeDO;
        }

        /** @return whether {@code node} is one of this edge's endpoints */
        public boolean contains(int node) {
            return node == edgePU || node == edgeDO;
        }

        public int getEdgePU() {
            return edgePU;
        }

        public int getEdgeDO() {
            return edgeDO;
        }
    }

}
