<?php

namespace Drupal\openid_client_advanced\Service;

use Firebase\JWT\BeforeValidException;
use Firebase\JWT\ExpiredException;
use Firebase\JWT\JWK;
use Firebase\JWT\JWT;
use Firebase\JWT\Key;
use Firebase\JWT\SignatureInvalidException;

/**
 * Validates JWT signatures using configured public keys.
 *
 * Keys may be supplied as PEM encoded public keys, as JWKS JSON documents (a
 * {"keys": [...]} object, a bare list of JWKs, or a single JWK), or as
 * \Firebase\JWT\Key instances.
 */
class JwtSignatureValidator {

  /**
   * Validate a JWT signature and return the decoded payload.
   *
   * Candidate keys are tried in order until one succeeds. A JWKS key whose
   * "kid" matches the token header is tried first. If every key fails, the
   * last failure is rethrown. Expiry and not-before errors are rethrown
   * immediately instead: php-jwt raises them only after the signature has
   * been verified with the current key, so trying other keys would only hide
   * the real problem behind an unrelated "signature verification failed".
   *
   * @param string $jwt
   *   The encoded JSON Web Token.
   * @param array $public_keys
   *   A list of PEM encoded public keys, JWKS JSON documents, or
   *   \Firebase\JWT\Key instances. Values of any other type, and empty
   *   strings, are ignored.
   * @param array $allowed_algorithms
   *   Algorithms that are accepted when verifying the signature.
   *
   * @return array
   *   The decoded JWT payload. Only the top-level object is cast to an array;
   *   nested objects are left as returned by JWT::decode().
   *
   * @throws \UnexpectedValueException
   *   Thrown when no keys are given, the token or its header is malformed, or
   *   a JWKS document cannot be decoded.
   * @throws \Firebase\JWT\SignatureInvalidException
   *   Thrown when the algorithm is not allowed, no usable key exists, or the
   *   signature could not be validated.
   * @throws \Firebase\JWT\ExpiredException
   *   Thrown when the signature is valid but the token has expired.
   * @throws \Firebase\JWT\BeforeValidException
   *   Thrown when the signature is valid but the token is not valid yet.
   */
  public function validate(string $jwt, array $public_keys, array $allowed_algorithms = ['RS256']): array {
    if (empty($public_keys)) {
      throw new \UnexpectedValueException('No public keys available for signature validation.');
    }

    // Inspect the JWT header to learn which algorithm and key id were used.
    $header = $this->decodeHeader($jwt);
    $algorithm = $header['alg'] ?? NULL;
    if (empty($algorithm)) {
      throw new \UnexpectedValueException('The JWT is missing the "alg" header value.');
    }
    // The header is untrusted input: refuse structured values before they
    // reach string-typed code and error messages.
    if (!is_string($algorithm)) {
      throw new \UnexpectedValueException('The JWT "alg" header value must be a string.');
    }

    if (!in_array($algorithm, $allowed_algorithms, TRUE)) {
      throw new SignatureInvalidException(sprintf('The token signature algorithm "%s" is not allowed.', $algorithm));
    }

    // Scalar "kid" values are converted to strings. Anything else (arrays,
    // objects) would raise a TypeError in prepareKeys(), so it is treated as
    // absent and every key is still tried.
    $kid = isset($header['kid']) && is_scalar($header['kid']) ? (string) $header['kid'] : NULL;

    $candidate_keys = $this->prepareKeys($public_keys, $algorithm, $kid);
    if (empty($candidate_keys)) {
      throw new SignatureInvalidException('No usable keys are available for signature validation.');
    }

    $last_exception = NULL;
    foreach ($candidate_keys as $key) {
      try {
        // Decode with the current key; success means verification passed.
        return (array) JWT::decode($jwt, $key);
      }
      catch (ExpiredException | BeforeValidException $exception) {
        // The signature verified with this key; the token is simply outside
        // its validity window. Report that rather than a later key's failure.
        throw $exception;
      }
      catch (\Throwable $exception) {
        // Remember the failure but keep trying remaining keys.
        $last_exception = $exception;
      }
    }

    if ($last_exception instanceof \Throwable) {
      throw $last_exception;
    }

    throw new SignatureInvalidException('Unable to validate the JWT signature with the configured public keys.');
  }

  /**
   * Build the list of keys that should be attempted for validation.
   *
   * JWKS keys whose identifier matches $kid come first, followed by all
   * other keys in configuration order. PEM material is never used with an
   * HMAC ("HS*") algorithm: the algorithm comes from the token header, and
   * accepting a public key as a shared secret would let anyone forge tokens
   * (algorithm confusion).
   *
   * @param array $public_keys
   *   Raw key values from configuration.
   * @param string $algorithm
   *   The algorithm from the JWT header.
   * @param string|null $kid
   *   The key identifier from the JWT header.
   *
   * @return \Firebase\JWT\Key[]
   *   A prioritized list of Key instances.
   *
   * @throws \UnexpectedValueException
   *   Thrown when a JWKS document is invalid JSON or has no matching keys.
   * @throws \Firebase\JWT\SignatureInvalidException
   *   Thrown when php-jwt cannot parse a JWKS document.
   */
  protected function prepareKeys(array $public_keys, string $algorithm, ?string $kid = NULL): array {
    $preferred = [];
    $fallback = [];

    foreach ($public_keys as $public_key) {
      if ($public_key instanceof Key) {
        $fallback[] = $public_key;
        continue;
      }

      if (!is_string($public_key)) {
        continue;
      }

      $trimmed = trim($public_key);
      if ($trimmed === '') {
        continue;
      }

      if ($this->looksLikeJsonDocument($trimmed)) {
        // Treat JSON strings as JWKS documents that can yield multiple keys.
        $jwks = $this->decodeJwks($trimmed, $algorithm);
        try {
          $parsed_keys = JWK::parseKeySet($jwks, $algorithm);
        }
        catch (\Throwable $exception) {
          throw new SignatureInvalidException(
            sprintf('Unable to parse JWKS document: %s', $exception->getMessage()),
            0,
            $exception
          );
        }

        if ($kid !== NULL && isset($parsed_keys[$kid]) && $parsed_keys[$kid] instanceof Key) {
          // Kid matches the header, so try this key before any others.
          $preferred[] = $parsed_keys[$kid];
          unset($parsed_keys[$kid]);
        }

        foreach ($parsed_keys as $parsed_key) {
          if ($parsed_key instanceof Key) {
            $fallback[] = $parsed_key;
          }
        }

        continue;
      }

      // Public PEM material must never become an HMAC secret.
      if (strncasecmp($algorithm, 'HS', 2) === 0 && str_contains($trimmed, '-----BEGIN')) {
        continue;
      }

      // Plain string keys are interpreted as PEM material for the algorithm.
      $fallback[] = new Key($trimmed, $algorithm);
    }

    return array_merge($preferred, $fallback);
  }

  /**
   * Decode and filter a JWKS payload.
   *
   * Accepts a {"keys": [...]} object, a bare list of JWKs, or a single JWK
   * object (detected by its "kty" member). A key is dropped when its "use"
   * is set to something other than "sig", or its "alg" is set and differs
   * (case-insensitively) from $algorithm. Non-array entries are dropped too.
   *
   * @param string $json
   *   The JWKS JSON string.
   * @param string $algorithm
   *   The expected algorithm.
   *
   * @return array
   *   A JWKS array ({"keys": [...]}) containing only usable keys.
   *
   * @throws \UnexpectedValueException
   *   Thrown when the JSON is invalid, has no recognizable key structure, or
   *   no key survives filtering.
   */
  protected function decodeJwks(string $json, string $algorithm): array {
    try {
      $decoded = json_decode($json, TRUE, 512, JSON_THROW_ON_ERROR);
    }
    catch (\JsonException $exception) {
      throw new \UnexpectedValueException(
        sprintf('Unable to parse JWKS JSON: %s', $exception->getMessage()),
        0,
        $exception
      );
    }

    if (isset($decoded['keys']) && is_array($decoded['keys'])) {
      $keys = $decoded['keys'];
    }
    elseif (is_array($decoded) && array_is_list($decoded)) {
      $keys = $decoded;
    }
    elseif (is_array($decoded) && isset($decoded['kty'])) {
      $keys = [$decoded];
    }
    else {
      throw new \UnexpectedValueException('The JWKS JSON does not contain a "keys" array or a single key definition.');
    }

    $filtered = [];
    foreach ($keys as $key) {
      if (!is_array($key)) {
        continue;
      }
      if (isset($key['use']) && $key['use'] !== 'sig') {
        continue;
      }
      if (isset($key['alg']) && strcasecmp((string) $key['alg'], $algorithm) !== 0) {
        continue;
      }
      // Keep only keys that can sign with the expected algorithm.
      $filtered[] = $key;
    }

    if (empty($filtered)) {
      throw new \UnexpectedValueException('No matching signature keys found in the JWKS JSON.');
    }

    return ['keys' => array_values($filtered)];
  }

  /**
   * Determine whether a string is likely JSON content.
   *
   * Only the first non-whitespace character is inspected ("{" or "[").
   *
   * @param string $value
   *   The value to test.
   *
   * @return bool
   *   TRUE if the value appears to be JSON, FALSE otherwise.
   */
  protected function looksLikeJsonDocument(string $value): bool {
    $value = ltrim($value);
    return $value !== '' && ($value[0] === '{' || $value[0] === '[');
  }

  /**
   * Decode the JWT header into an associative array.
   *
   * @param string $jwt
   *   The encoded JWT.
   *
   * @return array
   *   The decoded header.
   *
   * @throws \UnexpectedValueException
   *   Thrown when the token does not have three segments, or when the header
   *   segment is not valid base64url-encoded JSON.
   */
  protected function decodeHeader(string $jwt): array {
    $parts = explode('.', $jwt, 3);
    if (count($parts) !== 3) {
      throw new \UnexpectedValueException('Malformed JWT: expected three segments.');
    }

    // Segment 0 holds the header and uses base64url encoding.
    $decoded_json = $this->decodeSegment($parts[0]);
    $header = json_decode($decoded_json, TRUE);
    if (!is_array($header)) {
      throw new \UnexpectedValueException('Unable to decode the JWT header.');
    }

    return $header;
  }

  /**
   * Decode a JWT segment using URL safe base64 rules.
   *
   * @param string $segment
   *   The JWT segment.
   *
   * @return string
   *   The decoded segment value.
   *
   * @throws \UnexpectedValueException
   *   Thrown when the segment contains characters outside the base64url
   *   alphabet or has an impossible length.
   */
  protected function decodeSegment(string $segment): string {
    $remainder = strlen($segment) % 4;
    if ($remainder) {
      // Add padding because base64url may omit it,
      // but PHP's decoder requires it.
      $segment .= str_repeat('=', 4 - $remainder);
    }

    // Strict mode makes invalid input return FALSE instead of having stray
    // characters silently discarded.
    $decoded = base64_decode(strtr($segment, '-_', '+/'), TRUE);
    if ($decoded === FALSE) {
      throw new \UnexpectedValueException('Failed to base64 decode JWT segment.');
    }

    return $decoded;
  }

}
