import java.io.File;
import java.security.Key;
import java.security.KeyException;
import java.security.PublicKey;

import javax.xml.XMLConstants;
import javax.xml.crypto.AlgorithmMethod;
import javax.xml.crypto.KeySelector;
import javax.xml.crypto.KeySelectorException;
import javax.xml.crypto.KeySelectorResult;
import javax.xml.crypto.XMLCryptoContext;
import javax.xml.crypto.dsig.SignatureMethod;
import javax.xml.crypto.dsig.XMLSignature;
import javax.xml.crypto.dsig.XMLSignatureFactory;
import javax.xml.crypto.dsig.dom.DOMValidateContext;
import javax.xml.crypto.dsig.keyinfo.KeyInfo;
import javax.xml.crypto.dsig.keyinfo.KeyValue;
import javax.xml.parsers.DocumentBuilderFactory;

import org.w3c.dom.Attr;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

/**
 * Stand-alone harness that validates the XML signature contained in
 * {@code signature-wrapping.xml} (resolved against {@code DATA_DIR}) using the
 * JSR 105 API, with {@code org.jcp.xml.dsig.secureValidation} explicitly
 * disabled.
 *
 * <p>The verification key is taken from a {@code KeyValue} inside the
 * signature's own {@code KeyInfo}. {@code id}/{@code Id}/{@code xml:id}
 * attributes on {@code RetrievalMethod} elements are registered as XML IDs so
 * that same-document references to them can be dereferenced.
 */
public class TestSecure {
    private static final String DATA_DIR = "./";
    private static final File base = new File(DATA_DIR);
    /** Attribute names (qualified) treated as XML IDs by {@link #registerIdAttribute}. */
    private static final String[] ID_ATTRIBUTES = { "id", "Id", "xml:id" };
    /** Property name controlling the JSR 105 implementation's secure validation mode. */
    private static final String SECURE_VALIDATION = "org.jcp.xml.dsig.secureValidation";
    private static final KeySelector keyValueKS = new KeyValueKeySelector();

    /**
     * Parses the test document, validates its first {@code ds:Signature}
     * element, and prints the secure-validation setting and the validation
     * result.
     *
     * @param args ignored
     * @throws IllegalStateException if the document contains no
     *         {@code ds:Signature} element
     * @throws Exception if parsing, unmarshalling or validation fails
     */
    public static void main(String[] args) throws Exception {
        DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
        // Namespace awareness is required: both the Signature lookup and the
        // ID registration below match elements by namespace / local name.
        dbf.setNamespaceAware(true);
        dbf.setValidating(false);
        dbf.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, Boolean.TRUE);
        File input = new File(base, "signature-wrapping.xml");
        Document doc = dbf.newDocumentBuilder().parse(input);

        NodeList nl = doc.getElementsByTagNameNS(XMLSignature.XMLNS, "Signature");
        if (nl.getLength() == 0) {
            // Fail with a clear message. Otherwise DOMValidateContext would
            // throw a bare NullPointerException for the null node.
            throw new IllegalStateException("No Signature element found in " + input);
        }
        Element element = (Element) nl.item(0);

        DOMValidateContext vc = new DOMValidateContext(keyValueKS, element);
        vc.setBaseURI(base.toURI().toString());

        // Secure validation is explicitly switched off here. The value printed
        // below is therefore FALSE regardless of the JDK's own default.
        vc.setProperty(SECURE_VALIDATION, Boolean.FALSE);
        registerIdAttribute(vc, doc, "RetrievalMethod");

        System.out.println("Running without SecurityManager, use default "
                + "org.jcp.xml.dsig.secureValidation : "
                + vc.getProperty(SECURE_VALIDATION));

        XMLSignatureFactory factory = XMLSignatureFactory.getInstance();
        XMLSignature signature = factory.unmarshalXMLSignature(vc);
        // validate() reports a failed core validation by returning false, not
        // by throwing, so the result has to be surfaced explicitly.
        boolean valid = signature.validate(vc);
        System.out.println("Signature validation result: " + valid);
    }

    /**
     * Registers ID-typed attributes of every element whose local name is
     * {@code tagName} (in any namespace) with the validation context. This
     * lets same-document URI references such as {@code #foo} resolve to those
     * elements.
     *
     * <p>The attribute names checked are {@code id}, {@code Id} and
     * {@code xml:id}. Attributes with an empty value are skipped, because
     * {@link DOMValidateContext#setIdAttributeNS} rejects them with an
     * {@code IllegalArgumentException}. Expects {@code doc} to have been
     * parsed namespace-aware (as {@link #main} does), since elements are
     * matched by local name.
     *
     * @param vc      context in which the IDs are registered
     * @param doc     document to scan
     * @param tagName local name of the elements whose ID attributes are registered
     */
    static void registerIdAttribute(DOMValidateContext vc, Document doc,
            String tagName) {
        NodeList nodes = doc.getElementsByTagNameNS("*", tagName);
        for (int i = 0; i < nodes.getLength(); i++) {
            Node node = nodes.item(i);
            if (!(node instanceof Element)) {
                continue;
            }
            Element element = (Element) node;
            for (String attributeName : ID_ATTRIBUTES) {
                Attr a = element.getAttributeNode(attributeName);
                if (a != null && !a.getValue().isEmpty()) {
                    vc.setIdAttributeNS(element, a.getNamespaceURI(),
                            a.getLocalName());
                }
            }
        }
    }

    /**
     * KeySelector that returns the first public key found in a
     * {@code KeyValue} of the signature's {@code KeyInfo} whose key algorithm
     * is compatible with the signature method. The key is taken from the
     * document itself and is not checked against any trust anchor.
     */
    private static class KeyValueKeySelector extends KeySelector {
        @Override
        public KeySelectorResult select(KeyInfo keyInfo,
                KeySelector.Purpose purpose, AlgorithmMethod method,
                XMLCryptoContext context) throws KeySelectorException {
            if (keyInfo == null) {
                throw new KeySelectorException("Null KeyInfo object!");
            }
            // NOTE(review): assumes this selector is only used for signature
            // validation, where the algorithm is a SignatureMethod.
            SignatureMethod sm = (SignatureMethod) method;

            for (Object xmlStructure : keyInfo.getContent()) {
                if (!(xmlStructure instanceof KeyValue)) {
                    continue;
                }
                try {
                    final PublicKey pk = ((KeyValue) xmlStructure).getPublicKey();
                    if (algEquals(sm.getAlgorithm(), pk.getAlgorithm())) {
                        return new KeySelectorResult() {
                            @Override
                            public Key getKey() {
                                return pk;
                            }
                        };
                    }
                } catch (KeyException ke) {
                    throw new KeySelectorException(ke);
                }
            }
            // Also reached when KeyValues exist but none has a matching algorithm.
            throw new KeySelectorException("No KeyValue element found!");
        }

        /**
         * Returns whether a key whose JCA algorithm name is {@code algName}
         * can verify signatures of the method identified by {@code algURI}.
         * Only DSA with SHA-1, and RSA with SHA-1/256/384/512, are recognised.
         * The key algorithm name is compared case-insensitively.
         *
         * @param algURI  signature method algorithm URI
         * @param algName key algorithm name, e.g. {@code "RSA"}
         * @return {@code true} if the pair is a supported combination
         */
        static boolean algEquals(String algURI, String algName) {
            if (algName.equalsIgnoreCase("DSA")) {
                return algURI.equals(SignatureMethod.DSA_SHA1);
            }
            if (algName.equalsIgnoreCase("RSA")) {
                return algURI.equals(SignatureMethod.RSA_SHA1)
                        || algURI.equals("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256")
                        || algURI.equals("http://www.w3.org/2001/04/xmldsig-more#rsa-sha384")
                        || algURI.equals("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512");
            }
            return false;
        }
    }
}
