/*
 * Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 *
 */

package org.openjdk.bench.vm.gc;

import java.lang.reflect.Method;
import java.util.*;

import jdk.test.lib.classloader.ClassLoadUtils;
import jdk.test.whitebox.WhiteBox;
import jdk.test.whitebox.code.NMethod;

/**
 * Stand-alone harness that creates many compiled copies of one small method and
 * then times a single young collection triggered through the WhiteBox API.
 *
 * <p>What {@link #main(String[])} does, step by step:
 * <ol>
 *   <li>loads {@link B} once per requested copy, each time through a fresh class
 *       loader, so that every copy owns a distinct {@link Method};</li>
 *   <li>invokes the selected {@code testN} method once, marks it as profiled and
 *       compiles it at {@link #COMP_LEVEL};</li>
 *   <li>prints the PID and blocks on {@code System.in} so that an external tool
 *       can be attached;</li>
 *   <li>clears {@link #a}, calls {@code WhiteBox.youngGC()} and prints how long
 *       that call took.</li>
 * </ol>
 *
 * <p>Usage: {@code BarrierCost <test-method-name> [method-count]} where
 * {@code test-method-name} is one of {@code test0} .. {@code test9}.
 *
 * <p>NOTE(review): judging by the class and package names this is meant to measure
 * the GC cost attributable to compiled methods; nothing in this file states that
 * explicitly, so confirm against the accompanying change or bug report.
 */
public class BarrierCost {

    /** Compilation level that every generated test method is compiled at. */
    private static final int COMP_LEVEL = 1;

    /** Number of method copies created when no count is given on the command line. */
    private static final int DEFAULT_METHOD_COUNT = 10000;

    private static final WhiteBox WB = WhiteBox.getWhiteBox();

    /**
     * Holder of nine reference fields that the {@code testN} methods in {@link B}
     * store into.
     *
     * <p>The class and its fields are public because {@link B} is re-defined by
     * separate class loaders (see {@code createClassLoaderFor}) and accesses these
     * members from there.
     */
    public static class A {
        public String s1;
        public String s2;
        public String s3;
        public String s4;
        public String s5;
        public String s6;
        public String s7;
        public String s8;
        public String s9;
    }

    /** Shared target of all stores done by {@link B}; set to {@code null} by {@code main} before the GC. */
    public static A a = new A();

    /**
     * Methods under test. {@code testN} reads and re-assigns the first {@code N}
     * reference fields of {@link BarrierCost#a}; {@code test0} is the empty
     * baseline.
     *
     * <p>The bytes of this class are read from the class path and defined again in
     * a fresh class loader for every copy, so the bodies below must stay
     * self-contained static methods without parameters.
     */
    public static class B {
        public static void test0() {
        }

        public static void test1() {
            a.s1 = a.s1 + "1";
        }

        public static void test2() {
            a.s1 = a.s1 + "1";
            a.s2 = a.s2 + "2";
        }

        public static void test3() {
            a.s1 = a.s1 + "1";
            a.s2 = a.s2 + "2";
            a.s3 = a.s3 + "3";
        }

        public static void test4() {
            a.s1 = a.s1 + "1";
            a.s2 = a.s2 + "2";
            a.s3 = a.s3 + "3";
            a.s4 = a.s4 + "4";
        }

        public static void test5() {
            a.s1 = a.s1 + "1";
            a.s2 = a.s2 + "2";
            a.s3 = a.s3 + "3";
            a.s4 = a.s4 + "4";
            a.s5 = a.s5 + "5";
        }

        public static void test6() {
            a.s1 = a.s1 + "1";
            a.s2 = a.s2 + "2";
            a.s3 = a.s3 + "3";
            a.s4 = a.s4 + "4";
            a.s5 = a.s5 + "5";
            a.s6 = a.s6 + "6";
        }

        public static void test7() {
            a.s1 = a.s1 + "1";
            a.s2 = a.s2 + "2";
            a.s3 = a.s3 + "3";
            a.s4 = a.s4 + "4";
            a.s5 = a.s5 + "5";
            a.s6 = a.s6 + "6";
            a.s7 = a.s7 + "7";
        }

        public static void test8() {
            a.s1 = a.s1 + "1";
            a.s2 = a.s2 + "2";
            a.s3 = a.s3 + "3";
            a.s4 = a.s4 + "4";
            a.s5 = a.s5 + "5";
            a.s6 = a.s6 + "6";
            a.s7 = a.s7 + "7";
            a.s8 = a.s8 + "8";
        }

        public static void test9() {
            a.s1 = a.s1 + "1";
            a.s2 = a.s2 + "2";
            a.s3 = a.s3 + "3";
            a.s4 = a.s4 + "4";
            a.s5 = a.s5 + "5";
            a.s6 = a.s6 + "6";
            a.s7 = a.s7 + "7";
            a.s8 = a.s8 + "8";
            a.s9 = a.s9 + "9";
        }
    }

    /**
     * Creates a class loader that defines {@code className} itself from
     * {@code code} and delegates every other name to the default
     * {@link ClassLoader#loadClass(String)} behaviour.
     *
     * @param className binary name of the single class this loader defines
     * @param code      class-file bytes for {@code className}
     * @return a new loader; each call returns an independent instance
     */
    private static ClassLoader createClassLoaderFor(final String className, final byte[] code) {
        return new ClassLoader() {
            @Override
            public Class<?> loadClass(String name) throws ClassNotFoundException {
                if (!name.equals(className)) {
                    return super.loadClass(name);
                }

                // defineClass() throws LinkageError when the same name is defined
                // twice by one loader, so hand back the existing class on repeats.
                Class<?> alreadyDefined = findLoadedClass(name);
                if (alreadyDefined != null) {
                    return alreadyDefined;
                }

                return defineClass(name, code, 0, code.length);
            }
        };
    }

    /**
     * Builds {@code count} independent copies of the method {@code name} of
     * {@link B}, each defined by its own class loader, and runs, profiles and
     * compiles every copy.
     *
     * @param name  name of a public no-argument method of {@link B}
     * @param count number of copies to create
     * @return the created test methods, in creation order
     * @throws ClassNotFoundException if the class file of {@link B} cannot be
     *                                located on the class path
     * @throws NoSuchMethodException  if {@link B} has no public method {@code name}
     * @throws Exception              for any failure while reading the class file,
     *                                invoking or compiling a method
     */
    private static TestMethod[] createTestMethods(String name, int count) throws Exception {
        final String className = B.class.getName();
        final String fileName = ClassLoadUtils.getClassPathFileName(className);
        if (fileName == null) {
            throw new ClassNotFoundException(className);
        }

        final byte[] code = ClassLoadUtils.readFile(fileName);

        TestMethod[] testMethods = new TestMethod[count];
        for (int i = 0; i < count; i++) {
            // A new loader per iteration yields a new Class, hence a new Method.
            Class<?> cl = createClassLoaderFor(className, code).loadClass(className);
            Method method = cl.getMethod(name);
            testMethods[i] = new TestMethod(method);
            testMethods[i].profile();
            testMethods[i].compileWithC2();
        }

        return testMethods;
    }

    /** Wraps one reflective {@link Method} together with the WhiteBox operations applied to it. */
    private static final class TestMethod {
        private final Method method;

        /**
         * @param method static no-argument method to manage; inlining of it is
         *               disabled through WhiteBox
         */
        public TestMethod(Method method) throws Exception {
            this.method = method;
            WB.testSetDontInlineMethod(method, true);
        }

        /** Invokes the method once and then marks it as profiled via WhiteBox. */
        public void profile() throws Exception {
            method.invoke(null);
            WB.markMethodProfiled(method);
        }

        /** Invokes the method once with no receiver and no arguments. */
        public void invoke() throws Exception {
            method.invoke(null);
        }

        /**
         * Enqueues the method for compilation at {@link #COMP_LEVEL}, busy-waits
         * until it has left the compile queue and verifies the resulting level.
         *
         * <p>NOTE(review): despite the method name, the level requested is
         * {@code COMP_LEVEL} (currently 1). In HotSpot's tiered scheme level 1 is
         * normally a C1 tier and C2 is level 4 - confirm which one is intended.
         *
         * @throws IllegalStateException if the method did not end up compiled at
         *                               {@link #COMP_LEVEL}
         */
        public void compileWithC2() throws Exception {
            WB.enqueueMethodForCompilation(method, COMP_LEVEL);
            while (WB.isMethodQueuedForCompilation(method)) {
                Thread.onSpinWait();
            }
            // Query once so the reported level is the one that failed the check.
            final int actualLevel = WB.getMethodCompilationLevel(method);
            if (actualLevel != COMP_LEVEL) {
                throw new IllegalStateException("Method " + method + " is not compiled at the compilation level: " + COMP_LEVEL + ". Got: " + actualLevel);
            }
        }

        /** Returns the WhiteBox {@link NMethod} description looked up for this method (non-OSR). */
        public NMethod getNMethod() {
            return NMethod.get(method, false);
        }
    }

    /** Milliseconds elapsed since {@code startNanos}, a value obtained from {@link System#nanoTime()}. */
    private static long elapsedMillis(long startNanos) {
        return (System.nanoTime() - startNanos) / 1_000_000L;
    }

    /**
     * Entry point.
     *
     * @param args {@code args[0]} is the name of the {@link B} method to replicate
     *             (required); {@code args[1]} is the number of copies (optional,
     *             defaults to {@link #DEFAULT_METHOD_COUNT})
     * @throws IllegalArgumentException if the method name is missing or the count
     *                                  is negative
     * @throws NumberFormatException    if the count is not a valid integer
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            throw new IllegalArgumentException("Usage: BarrierCost <test-method-name> [method-count]");
        }
        final String methodName = args[0];
        final int methodCount = (args.length >= 2) ? Integer.parseInt(args[1]) : DEFAULT_METHOD_COUNT;
        if (methodCount < 0) {
            throw new IllegalArgumentException("method-count must not be negative, got: " + methodCount);
        }

        // nanoTime() is monotonic, unlike currentTimeMillis(), so the measured
        // intervals are not disturbed by wall-clock adjustments.
        long start = System.nanoTime();
        // Kept in a local for the rest of main; presumably so the generated
        // classes and their compiled code stay referenced while the GC runs.
        TestMethod[] testMethods = createTestMethods(methodName, methodCount);
        System.out.println("Created test methods: " + elapsedMillis(start) + " ms");
        System.out.println("PID: " + ProcessHandle.current().pid());
        System.out.println("Press any key to start young GC...");
        System.in.read();
        // Drop the only reference to the A instance the test methods stored into.
        a = null;
        start = System.nanoTime();
        WB.youngGC();
        System.out.println("Young GC: " + elapsedMillis(start) + " ms");
    }
}
