import java.io.ObjectInputFilter;
import java.util.function.BinaryOperator;

// example from JEP415
/**
 * A deserialization filter factory that folds a thread-specific filter into the
 * filters it is asked to combine.
 *
 * <p>The thread-specific filter is kept in a {@link ThreadLocal}, so a filter installed
 * through {@link #doWithSerialFilter(ObjectInputFilter, Runnable)} is only visible to
 * {@link #apply(ObjectInputFilter, ObjectInputFilter)} calls made on the same thread
 * while the runnable is executing.
 */
public class FilterInThread implements BinaryOperator<ObjectInputFilter> {
    // Filter in effect for the current thread; null when none has been installed.
    private final ThreadLocal<ObjectInputFilter> filterThreadLocal =
            new ThreadLocal<>();

    /** Constructs a FilterInThread deserialization filter factory. */
    public FilterInThread() {}

    /**
     * Returns a composite of the supplied filters and the thread-specific filter, wrapped
     * so that classes left {@code UNDECIDED} are rejected.
     *
     * <ul>
     *   <li>When {@code curr} is {@code null}, the thread-specific filter (if any) is
     *       combined with {@code next} (if any).</li>
     *   <li>When {@code curr} is non-null, {@code next} (if any) is merged in front of
     *       {@code curr}; otherwise {@code curr} is returned unchanged.</li>
     * </ul>
     *
     * @param curr the current filter, or {@code null} if there is none
     * @param next the filter being requested, or {@code null} if there is none
     * @return the filter to use; {@code null} when {@code curr} and {@code next} are both
     *         {@code null} and no thread-specific filter is installed
     */
    @Override
    public ObjectInputFilter apply(ObjectInputFilter curr, ObjectInputFilter next) {
        if (curr == null) {
            // Called from the OIS constructor or perhaps OIS.setObjectInputFilter with no current filter
            var filter = filterThreadLocal.get();
            if (filter != null) {
                // Wrap the thread filter to reject UNDECIDED results
                filter = ObjectInputFilter.rejectUndecidedClass(filter);
            }
            if (next != null) {
                // Merge the next filter with the thread filter, if any.
                // Initially this is the static JVM-wide filter passed from the OIS constructor.
                // merge() tolerates a null second argument and then returns next as-is.
                filter = ObjectInputFilter.merge(next, filter);
                // Wrap the merged filter to reject UNDECIDED results
                filter = ObjectInputFilter.rejectUndecidedClass(filter);
            }
            return filter;
        } else {
            // Called from OIS.setObjectInputFilter with a current filter and a stream-specific filter.
            // The curr filter already incorporates the thread filter and static JVM-wide filter
            // and rejection of undecided classes.
            // If there is a stream-specific filter, merge it with curr and recheck for undecided.
            if (next != null) {
                next = ObjectInputFilter.merge(next, curr);
                next = ObjectInputFilter.rejectUndecidedClass(next);
                return next;
            }
            return curr;
        }
    }

    /**
     * Installs {@code filter} as the calling thread's filter, invokes {@code runnable},
     * and then restores whatever filter the thread had before the call, even if the
     * runnable throws.
     *
     * @param filter the filter to apply on this thread while the runnable executes;
     *        {@code null} means no thread-specific filter
     * @param runnable the action to invoke
     * @throws NullPointerException if {@code runnable} is {@code null}
     */
    public void doWithSerialFilter(ObjectInputFilter filter, Runnable runnable) {
        var prevFilter = filterThreadLocal.get();
        try {
            filterThreadLocal.set(filter);
            runnable.run();
        } finally {
            if (prevFilter == null) {
                // Nothing to restore: clear the slot instead of storing null so the
                // thread does not keep holding on to this call's filter entry.
                filterThreadLocal.remove();
            } else {
                // Nested call: put the enclosing call's filter back.
                filterThreadLocal.set(prevFilter);
            }
        }
    }
}
