package uk.co.spudsoft.jwtvalidatorvertx.impl;

import com.google.common.base.Strings;
import com.google.common.collect.ImmutableSet;
import io.vertx.core.Future;
import io.vertx.ext.auth.impl.jose.JWK;
import io.vertx.ext.auth.impl.jose.JWS;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.Base64;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.co.spudsoft.jwtvalidatorvertx.IssuerAcceptabilityHandler;
import uk.co.spudsoft.jwtvalidatorvertx.JsonWebKeySetHandler;
import uk.co.spudsoft.jwtvalidatorvertx.Jwt;
import uk.co.spudsoft.jwtvalidatorvertx.JwtValidator;

/* loaded from: input_file:uk/co/spudsoft/jwtvalidatorvertx/impl/JwtValidatorVertxImpl.class */
/**
 * {@link JwtValidator} implementation that locates the signing key through a
 * {@link JsonWebKeySetHandler}, verifies the token signature with the Vert.x {@link JWS} class
 * and then checks the {@code iss}, {@code nbf}, {@code exp}, {@code aud} and {@code sub} claims.
 *
 * <p>Configuration is held in plain (non-volatile, unsynchronized) fields that are changed by the
 * fluent setters; no synchronization is performed by this class.
 */
public class JwtValidatorVertxImpl implements JwtValidator {
    private static final Logger logger = LoggerFactory.getLogger(JwtValidatorVertxImpl.class);
    private static final Base64.Decoder B64DECODER = Base64.getUrlDecoder();

    /** The complete set of algorithms that may ever be permitted; also the initial permitted set. */
    private static final Set<String> DEFAULT_PERMITTED_ALGS = ImmutableSet.of(
            "EdDSA", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512", "ES256K", "RS256", "RS384", "RS512");

    /** The exp and nbf claims are compared, after multiplication by this factor, with {@link System#currentTimeMillis()}. */
    private static final long MILLIS_PER_SECOND = 1000L;

    private final JsonWebKeySetHandler jsonWebKeySetHandler;
    private final IssuerAcceptabilityHandler issuerAcceptabilityHandler;
    private boolean requireExp = true;
    private boolean requireNbf = true;
    private long timeLeewayMilliseconds = 0;
    // Stored by setMinimumKeyCacheLifetime; not read anywhere else in this class.
    private long minimumKeyCacheLifetime = 0;
    private Set<String> permittedAlgs = new HashSet<>(DEFAULT_PERMITTED_ALGS);

    /**
     * Creates a validator.
     *
     * @param jsonWebKeySetHandler used to look up the JWK for a token's issuer and key id.
     * @param issuerAcceptabilityHandler used to decide whether the issuer found in a token is acceptable.
     */
    public JwtValidatorVertxImpl(JsonWebKeySetHandler jsonWebKeySetHandler, IssuerAcceptabilityHandler issuerAcceptabilityHandler) {
        this.jsonWebKeySetHandler = jsonWebKeySetHandler;
        this.issuerAcceptabilityHandler = issuerAcceptabilityHandler;
    }

    /**
     * Returns an immutable snapshot of the currently permitted signature algorithms.
     *
     * @return an immutable copy of the permitted algorithm names.
     */
    @Override
    public Set<String> getPermittedAlgorithms() {
        return ImmutableSet.copyOf(this.permittedAlgs);
    }

    /**
     * Replaces the set of permitted signature algorithms.
     *
     * <p>The current set is only replaced if every requested algorithm is supported.
     *
     * @param set the algorithm names to permit.
     * @return this, for fluent configuration.
     * @throws NoSuchAlgorithmException if any requested algorithm is not one of the supported algorithms.
     */
    @Override
    public JwtValidator setPermittedAlgorithms(Set<String> set) throws NoSuchAlgorithmException {
        Set<String> replacement = new HashSet<>();
        for (String algorithm : set) {
            requireSupportedAlgorithm(algorithm);
            replacement.add(algorithm);
        }
        this.permittedAlgs = replacement;
        return this;
    }

    /**
     * Adds a single algorithm to the set of permitted signature algorithms.
     *
     * @param str the algorithm name to permit.
     * @return this, for fluent configuration.
     * @throws NoSuchAlgorithmException if the algorithm is not one of the supported algorithms.
     */
    @Override
    public JwtValidator addPermittedAlgorithm(String str) throws NoSuchAlgorithmException {
        requireSupportedAlgorithm(str);
        this.permittedAlgs.add(str);
        return this;
    }

    /**
     * Sets the tolerance applied to the exp and nbf checks.
     *
     * @param duration the leeway; stored with millisecond precision.
     * @return this, for fluent configuration.
     */
    @Override
    public JwtValidator setTimeLeeway(Duration duration) {
        this.timeLeewayMilliseconds = duration.toMillis();
        return this;
    }

    /**
     * Records the minimum key cache lifetime.
     *
     * @param duration the lifetime; stored with millisecond precision.
     * @return this, for fluent configuration.
     */
    @Override
    public JwtValidator setMinimumKeyCacheLifetime(Duration duration) {
        this.minimumKeyCacheLifetime = duration.toMillis();
        return this;
    }

    /**
     * Sets whether a token without an exp claim is rejected.
     *
     * @param z true to reject tokens that have no exp claim.
     * @return this, for fluent configuration.
     */
    @Override
    public JwtValidator setRequireExp(boolean z) {
        this.requireExp = z;
        return this;
    }

    /**
     * Sets whether a token without an nbf claim is rejected.
     *
     * @param z true to reject tokens that have no nbf claim.
     * @return this, for fluent configuration.
     */
    @Override
    public JwtValidator setRequireNbf(boolean z) {
        this.requireNbf = z;
        return this;
    }

    /**
     * Parses and validates a JWT.
     *
     * <p>The token is parsed, its algorithm is checked against the permitted set, the JWK is looked up
     * using {@code str} and the token's key id, and then the signature and claims are validated.
     *
     * @param str the expected issuer; used for the JWK lookup and, when not null, required to equal the token issuer.
     * @param str2 the encoded token.
     * @param list the required audiences; at least one must appear in the token's aud claim unless {@code z} is true.
     * @param z when true, an empty {@code list} is accepted and a token whose aud claim matches none of the
     *          required audiences is still accepted (a null {@code list} or a token without an aud claim is
     *          rejected regardless).
     * @return a Future completed with the parsed token when it is valid, otherwise a failed Future.
     */
    @Override
    public Future<Jwt> validateToken(String str, String str2, List<String> list, boolean z) {
        final Jwt jwt;
        try {
            jwt = Jwt.parseJws(str2);
        } catch (Throwable parseFailure) {
            // The raw token is only written to the log when trace logging is enabled.
            if (logger.isTraceEnabled()) {
                logger.error("Parse of JWT ({}) failed: ", str2, parseFailure);
            } else {
                logger.error("Parse of JWT failed: ", parseFailure);
            }
            return Future.failedFuture(new IllegalArgumentException("Parse of signed JWT failed", parseFailure));
        }

        try {
            validateAlgorithm(jwt.getAlgorithm());
            String kid = jwt.getKid();
            if (jwt.getPayloadSize() == 0) {
                logger.error("No payload claims found in JWT");
                return Future.failedFuture(new IllegalArgumentException("Parse of signed JWT failed"));
            }
            return this.jsonWebKeySetHandler.findJwk(str, kid)
                    .onFailure(lookupFailure -> logger.warn("Failed to find JWK for {} ({}): ", kid, str, lookupFailure))
                    .compose(jwk -> verifyAndValidateClaims(jwk, jwt, str, list, z));
        } catch (Throwable processingFailure) {
            logger.error("Failed to process token: ", processingFailure);
            return Future.failedFuture(processingFailure);
        }
    }

    /**
     * Verifies the signature and then each claim, converting any failure into a failed Future.
     */
    private Future<Jwt> verifyAndValidateClaims(JWK jwk, Jwt jwt, String expectedIssuer, List<String> requiredAudiences, boolean ignoreRequiredAud) {
        try {
            verify(jwk, jwt);
            // A single timestamp is used for both time based checks.
            long nowMillis = System.currentTimeMillis();
            validateIssuer(jwt, expectedIssuer);
            validateNbf(jwt, nowMillis);
            validateExp(jwt, nowMillis);
            validateAud(jwt, requiredAudiences, ignoreRequiredAud);
            validateSub(jwt);
            return Future.succeededFuture(jwt);
        } catch (Throwable validationFailure) {
            logger.info("Validation of {} token failed: ", jwt.getAlgorithm(), validationFailure);
            return Future.failedFuture(new IllegalArgumentException("Validation of " + jwt.getAlgorithm() + " signed JWT failed", validationFailure));
        }
    }

    /**
     * Throws if the algorithm is not one of the algorithms this class is able to permit.
     */
    private static void requireSupportedAlgorithm(String algorithm) throws NoSuchAlgorithmException {
        if (!DEFAULT_PERMITTED_ALGS.contains(algorithm)) {
            throw new NoSuchAlgorithmException("Algorithm \"" + algorithm + "\" is not supported");
        }
    }

    /**
     * Converts a claim value in seconds to milliseconds, clamping to the range of long rather than
     * silently wrapping around (a wrapped value could make a far-future nbf appear to be in the past).
     */
    private static long secondsToMillis(long seconds) {
        if (seconds > Long.MAX_VALUE / MILLIS_PER_SECOND) {
            return Long.MAX_VALUE;
        }
        if (seconds < Long.MIN_VALUE / MILLIS_PER_SECOND) {
            return Long.MIN_VALUE;
        }
        return seconds * MILLIS_PER_SECOND;
    }

    /**
     * Checks that the token has an issuer, that the issuer is acceptable to the
     * {@link IssuerAcceptabilityHandler} and, when an expected issuer is given, that they are equal.
     *
     * @throws IllegalStateException if any of those checks fails.
     */
    private void validateIssuer(Jwt jwt, String expectedIssuer) {
        String tokenIssuer = jwt.getIssuer();
        if (Strings.isNullOrEmpty(tokenIssuer)) {
            throw new IllegalStateException("No issuer in token.");
        }
        if (!this.issuerAcceptabilityHandler.isAcceptable(tokenIssuer)) {
            throw new IllegalStateException("Issuer from token (" + tokenIssuer + ") is not acceptable.");
        }
        if (expectedIssuer != null && !expectedIssuer.equals(tokenIssuer)) {
            throw new IllegalStateException("Issuer from token (" + tokenIssuer + ") does not match expected issuer (" + expectedIssuer + ").");
        }
    }

    /**
     * Verifies the token signature using the given key.
     *
     * @throws IllegalStateException if the token has no signature or the key's algorithm is "none".
     * @throws NullPointerException if the key is null.
     * @throws IllegalArgumentException if the signature does not verify, or verification itself fails.
     */
    private void verify(JWK jwk, Jwt jwt) {
        if (Strings.isNullOrEmpty(jwt.getSignature())) {
            throw new IllegalStateException("No signature in token.");
        }
        Objects.requireNonNull(jwk, "JWK not set");
        if ("none".equals(jwk.getAlgorithm())) {
            throw new IllegalStateException("Algorithm \"none\" not allowed");
        }

        boolean verified;
        try {
            byte[] signature = B64DECODER.decode(jwt.getSignature());
            byte[] signedContent = jwt.getSignatureBase().getBytes(StandardCharsets.UTF_8);
            verified = new JWS(jwk).verify(signature, signedContent);
        } catch (Throwable verificationFailure) {
            logger.warn("Signature verification failed: ", verificationFailure);
            throw new IllegalArgumentException("Signature verification failed", verificationFailure);
        }
        // Thrown outside the try block so that the rejection is not caught and wrapped in a second, identical exception.
        if (!verified) {
            logger.warn("Signature verification failed: signature does not match");
            throw new IllegalArgumentException("Signature verification failed");
        }
    }

    /**
     * Checks that the token has a non-empty subject.
     *
     * @throws IllegalArgumentException if the subject is null or empty.
     */
    private void validateSub(Jwt jwt) {
        if (Strings.isNullOrEmpty(jwt.getSubject())) {
            throw new IllegalArgumentException("No subject specified in token");
        }
    }

    /**
     * Checks that at least one of the required audiences appears in the token's aud claim.
     *
     * @param jwt the token being validated.
     * @param requiredAudiences the acceptable audiences.
     * @param ignoreRequiredAud when true an empty required list is permitted and the absence of a match is not an error.
     * @throws IllegalStateException if the required list is null, or is empty while {@code ignoreRequiredAud} is false.
     * @throws IllegalArgumentException if the token has no aud claim, or no audience matches and
     *         {@code ignoreRequiredAud} is false.
     */
    private void validateAud(Jwt jwt, List<String> requiredAudiences, boolean ignoreRequiredAud) {
        if (requiredAudiences == null || (!ignoreRequiredAud && requiredAudiences.isEmpty())) {
            throw new IllegalStateException("Required audience not set");
        }
        if (jwt.getAudience() == null) {
            throw new IllegalArgumentException("Token does not include aud claim");
        }
        for (String tokenAudience : jwt.getAudience()) {
            for (String requiredAudience : requiredAudiences) {
                if (requiredAudience.equals(tokenAudience)) {
                    return;
                }
            }
        }
        if (ignoreRequiredAud) {
            return;
        }
        if (requiredAudiences.size() == 1) {
            logger.warn("Required audience ({}) not found in token aud claim: {}", requiredAudiences.get(0), jwt.getAudience());
        } else {
            logger.warn("None of the required audiences ({}) found in token aud claim: {}", requiredAudiences, jwt.getAudience());
        }
        throw new IllegalArgumentException("Required audience not found in token");
    }

    /**
     * Checks the exp claim: the token is rejected if it expired more than the configured leeway before now.
     *
     * @param jwt the token being validated.
     * @param nowMillis the current time in milliseconds since the epoch.
     * @throws IllegalArgumentException if exp is missing while required, or the token has expired.
     */
    private void validateExp(Jwt jwt, long nowMillis) {
        if (jwt.getExpiration() == null) {
            if (this.requireExp) {
                throw new IllegalArgumentException("Token does not specify exp");
            }
            return;
        }
        long earliestAcceptableExp = nowMillis - this.timeLeewayMilliseconds;
        if (secondsToMillis(jwt.getExpiration().longValue()) < earliestAcceptableExp) {
            logger.warn("Token exp = {} ({}), now = {} ({}), target = {} ({})",
                    jwt.getExpiration(), jwt.getExpirationLocalDateTime(),
                    nowMillis, LocalDateTime.ofInstant(Instant.ofEpochMilli(nowMillis), ZoneOffset.UTC),
                    earliestAcceptableExp, LocalDateTime.ofInstant(Instant.ofEpochMilli(earliestAcceptableExp), ZoneOffset.UTC));
            throw new IllegalArgumentException("Token is not valid after " + jwt.getExpirationLocalDateTime());
        }
    }

    /**
     * Checks the nbf claim: the token is rejected if it only becomes valid more than the configured leeway after now.
     *
     * @param jwt the token being validated.
     * @param nowMillis the current time in milliseconds since the epoch.
     * @throws IllegalArgumentException if nbf is missing while required, or the token is not yet valid.
     */
    private void validateNbf(Jwt jwt, long nowMillis) {
        if (jwt.getNotBefore() == null) {
            if (this.requireNbf) {
                throw new IllegalArgumentException("Token does not specify nbf");
            }
            return;
        }
        long latestAcceptableNbf = nowMillis + this.timeLeewayMilliseconds;
        if (secondsToMillis(jwt.getNotBefore().longValue()) > latestAcceptableNbf) {
            logger.warn("Token nbf = {} ({}), now = {} ({}), target = {} ({})",
                    jwt.getNotBefore(), jwt.getNotBeforeLocalDateTime(),
                    nowMillis, LocalDateTime.ofInstant(Instant.ofEpochMilli(nowMillis), ZoneOffset.UTC),
                    latestAcceptableNbf, LocalDateTime.ofInstant(Instant.ofEpochMilli(latestAcceptableNbf), ZoneOffset.UTC));
            throw new IllegalArgumentException("Token is not valid until " + jwt.getNotBeforeLocalDateTime());
        }
    }

    /**
     * Checks that the token's algorithm is present and currently permitted.
     *
     * <p>The exception message is the same for both failures; the detail is only written to the log.
     *
     * @throws IllegalArgumentException if the algorithm is null or not in the permitted set.
     */
    private void validateAlgorithm(String algorithm) {
        if (algorithm == null) {
            logger.warn("No signature algorithm in token.");
            throw new IllegalArgumentException("Parse of signed JWT failed");
        }
        if (!this.permittedAlgs.contains(algorithm)) {
            logger.warn("Failed to find algorithm \"{}\" in {}", algorithm, this.permittedAlgs);
            throw new IllegalArgumentException("Parse of signed JWT failed");
        }
    }
}
