package uk.gov.ida.saml.metadata;

import java.security.KeyStore;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.HashSet;
import java.util.Objects;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.validation.constraints.NotNull;
import javax.xml.namespace.QName;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.saml.metadata.resolver.filter.MetadataFilter;
import org.opensaml.saml.saml2.metadata.EntitiesDescriptor;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml.saml2.metadata.RoleDescriptor;
import org.opensaml.xmlsec.keyinfo.KeyInfoSupport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.gov.ida.common.shared.security.verification.CertificateChainValidator;
import uk.gov.ida.saml.metadata.exception.CertificateConversionException;

/* loaded from: input_file:uk/gov/ida/saml/metadata/CertificateChainValidationFilter.class */
public final class CertificateChainValidationFilter implements MetadataFilter {
    private static final Logger LOG = LoggerFactory.getLogger(CertificateChainValidationFilter.class);
    private final QName role;
    private final CertificateChainValidator certificateChainValidator;
    private final KeyStore keyStore;

    public CertificateChainValidationFilter(@NotNull QName qName, @NotNull CertificateChainValidator certificateChainValidator, @NotNull KeyStore keyStore) {
        this.role = qName;
        this.certificateChainValidator = certificateChainValidator;
        this.keyStore = keyStore;
    }

    public QName getRole() {
        return this.role;
    }

    public CertificateChainValidator getCertificateChainValidator() {
        return this.certificateChainValidator;
    }

    private KeyStore getKeyStore() {
        return this.keyStore;
    }

    @Nullable
    public XMLObject filter(@Nullable XMLObject xMLObject) {
        if (xMLObject == null) {
            return null;
        }
        try {
            if (xMLObject instanceof EntityDescriptor) {
                EntityDescriptor entityDescriptor = (EntityDescriptor) xMLObject;
                filterOutUntrustedRoleDescriptors(entityDescriptor);
                if (entityDescriptor.getRoleDescriptors().isEmpty()) {
                    LOG.warn("EntityDescriptor '{}' has empty role descriptor list, metadata will be filtered out", entityDescriptor.getEntityID());
                    return null;
                }
            } else {
                if (!(xMLObject instanceof EntitiesDescriptor)) {
                    LOG.error("Internal error, metadata object was of an unsupported type: {}", xMLObject.getClass().getName());
                    return null;
                }
                EntitiesDescriptor entitiesDescriptor = (EntitiesDescriptor) xMLObject;
                filterOutUntrustedEntityDescriptors(entitiesDescriptor);
                if (entitiesDescriptor.getEntityDescriptors().isEmpty()) {
                    LOG.warn("EntitiesDescriptor '{}' has empty entity descriptor list, metadata will be filtered out");
                    return null;
                }
            }
            return xMLObject;
        } catch (CertificateConversionException e) {
            LOG.error("Saw fatal error validating certificate chain, metadata will be filtered out", e);
            return null;
        }
    }

    private void filterOutUntrustedEntityDescriptors(@Nonnull EntitiesDescriptor entitiesDescriptor) {
        LOG.trace("Processing EntitiesDescriptor group: {}", getGroupName(entitiesDescriptor));
        HashSet hashSet = new HashSet();
        entitiesDescriptor.getEntityDescriptors().forEach(entityDescriptor -> {
            filterOutUntrustedRoleDescriptors(entityDescriptor);
            if (entityDescriptor.getRoleDescriptors().isEmpty()) {
                LOG.warn("EntityDescriptor '{}' has empty role descriptor list, removing from metadata", entityDescriptor.getEntityID());
                hashSet.add(entityDescriptor);
            }
        });
        if (hashSet.isEmpty()) {
            return;
        }
        entitiesDescriptor.getEntityDescriptors().removeAll(hashSet);
        hashSet.clear();
    }

    private void filterOutUntrustedRoleDescriptors(@Nonnull EntityDescriptor entityDescriptor) {
        String entityID = entityDescriptor.getEntityID();
        LOG.trace("Processing EntityDescriptor: {}", entityID);
        entityDescriptor.getRoleDescriptors().removeIf(roleDescriptor -> {
            if (!getRole().equals(roleDescriptor.getElementQName())) {
                return false;
            }
            filterOutUntrustedKeyDescriptors(roleDescriptor);
            if (!roleDescriptor.getKeyDescriptors().isEmpty()) {
                return false;
            }
            LOG.warn("KeyDescriptor '{}' has empty key descriptor list, removing from metadata", entityID);
            return true;
        });
    }

    private void filterOutUntrustedKeyDescriptors(@Nonnull RoleDescriptor roleDescriptor) {
        roleDescriptor.getKeyDescriptors().removeIf(keyDescriptor -> {
            try {
                for (X509Certificate x509Certificate : KeyInfoSupport.getCertificates(keyDescriptor.getKeyInfo())) {
                    if (!getCertificateChainValidator().validate(x509Certificate, getKeyStore()).isValid()) {
                        LOG.warn("Certificate chain validation failed for metadata entry {}", x509Certificate.getSubjectDN());
                        return true;
                    }
                }
                return false;
            } catch (CertificateException e) {
                throw new CertificateConversionException(e);
            }
        });
    }

    private String getGroupName(EntitiesDescriptor entitiesDescriptor) {
        String name = entitiesDescriptor.getName();
        if (name != null) {
            return name;
        }
        String id = entitiesDescriptor.getID();
        return id != null ? id : "(unnamed)";
    }
}
