<?php

namespace Drupal\graphql_shield\Middleware;

use Drupal\graphql_shield\Service\QueryComplexityAnalyzer;
use Drupal\graphql_shield\Service\RateLimiter;
use Drupal\graphql_shield\Service\PersistedQueryManager;
use Drupal\graphql_shield\Service\AuthenticationManager;
use Drupal\graphql_shield\Service\IntrospectionController;
use Drupal\graphql_shield\Service\RequestValidator;
use Drupal\graphql_shield\Service\InputSanitizer;
use Drupal\graphql_shield\Service\SecurityLogger;
use Drupal\graphql_shield\Service\IpRestrictor;
use Drupal\graphql_shield\Service\DosProtector;
use Drupal\graphql_shield\Service\QueryAnalyzer;
use Drupal\graphql_shield\Service\ErrorHandler;
use Drupal\Core\Entity\EntityTypeManagerInterface;
use Symfony\Component\HttpFoundation\Request;
use Symfony\Component\HttpFoundation\Response;
use Symfony\Component\HttpFoundation\JsonResponse;
use Symfony\Component\HttpKernel\HttpKernelInterface;

/**
 * HTTP middleware that applies GraphQL security checks.
 *
 * Requests whose path matches a configured GraphQL server endpoint are run
 * through IP restriction, circuit breaker, rate limiting, persisted query,
 * request validation, input sanitization, introspection, complexity, DoS
 * pattern and query analysis checks before reaching the wrapped kernel. A
 * failed check short-circuits with a JSON body of the form
 * {"errors": [...]}. All other requests pass straight through.
 */
class GraphQLSecurityMiddleware implements HttpKernelInterface {

  /**
   * The wrapped HTTP kernel.
   *
   * @var \Symfony\Component\HttpKernel\HttpKernelInterface
   */
  protected $httpKernel;

  /**
   * The query complexity analyzer.
   *
   * @var \Drupal\graphql_shield\Service\QueryComplexityAnalyzer
   */
  protected $complexityAnalyzer;

  /**
   * The rate limiter.
   *
   * @var \Drupal\graphql_shield\Service\RateLimiter
   */
  protected $rateLimiter;

  /**
   * The persisted query manager.
   *
   * @var \Drupal\graphql_shield\Service\PersistedQueryManager
   */
  protected $persistedQueryManager;

  /**
   * The authentication manager.
   *
   * @var \Drupal\graphql_shield\Service\AuthenticationManager
   */
  protected $authManager;

  /**
   * The introspection controller.
   *
   * @var \Drupal\graphql_shield\Service\IntrospectionController
   */
  protected $introspectionController;

  /**
   * The request validator.
   *
   * @var \Drupal\graphql_shield\Service\RequestValidator
   */
  protected $requestValidator;

  /**
   * The input sanitizer.
   *
   * @var \Drupal\graphql_shield\Service\InputSanitizer
   */
  protected $inputSanitizer;

  /**
   * The security logger.
   *
   * @var \Drupal\graphql_shield\Service\SecurityLogger
   */
  protected $securityLogger;

  /**
   * The IP restrictor.
   *
   * @var \Drupal\graphql_shield\Service\IpRestrictor
   */
  protected $ipRestrictor;

  /**
   * The DoS protector.
   *
   * @var \Drupal\graphql_shield\Service\DosProtector
   */
  protected $dosProtector;

  /**
   * The query analyzer.
   *
   * @var \Drupal\graphql_shield\Service\QueryAnalyzer
   */
  protected $queryAnalyzer;

  /**
   * The error handler.
   *
   * @var \Drupal\graphql_shield\Service\ErrorHandler
   */
  protected $errorHandler;

  /**
   * The entity type manager.
   *
   * @var \Drupal\Core\Entity\EntityTypeManagerInterface
   */
  protected $entityTypeManager;

  /**
   * Cached GraphQL endpoint paths; NULL until first computed.
   *
   * @var array|null
   */
  protected $graphqlEndpoints;

  /**
   * Constructs a GraphQLSecurityMiddleware object.
   */
  public function __construct(
    HttpKernelInterface $http_kernel,
    QueryComplexityAnalyzer $complexity_analyzer,
    RateLimiter $rate_limiter,
    PersistedQueryManager $persisted_query_manager,
    AuthenticationManager $auth_manager,
    IntrospectionController $introspection_controller,
    RequestValidator $request_validator,
    InputSanitizer $input_sanitizer,
    SecurityLogger $security_logger,
    IpRestrictor $ip_restrictor,
    DosProtector $dos_protector,
    QueryAnalyzer $query_analyzer,
    ErrorHandler $error_handler,
    EntityTypeManagerInterface $entity_type_manager,
  ) {
    $this->httpKernel = $http_kernel;
    $this->complexityAnalyzer = $complexity_analyzer;
    $this->rateLimiter = $rate_limiter;
    $this->persistedQueryManager = $persisted_query_manager;
    $this->authManager = $auth_manager;
    $this->introspectionController = $introspection_controller;
    $this->requestValidator = $request_validator;
    $this->inputSanitizer = $input_sanitizer;
    $this->securityLogger = $security_logger;
    $this->ipRestrictor = $ip_restrictor;
    $this->dosProtector = $dos_protector;
    $this->queryAnalyzer = $query_analyzer;
    $this->errorHandler = $error_handler;
    $this->entityTypeManager = $entity_type_manager;
  }

  /**
   * {@inheritdoc}
   *
   * Checks run in two phases: request-level checks (IP, circuit breaker,
   * rate limit) once per request, then per-operation checks for every
   * operation the request carries, so each entry of a batched request is
   * inspected. Exceptions raised by the checks are logged and turned into a
   * generic 500 JSON error (failing closed). Exceptions raised by the wrapped
   * kernel propagate, so the $catch contract of HttpKernelInterface holds.
   */
  public function handle(Request $request, $type = self::MAIN_REQUEST, $catch = TRUE): Response {
    // Only process GraphQL requests.
    if (!$this->isGraphqlRequest($request)) {
      return $this->httpKernel->handle($request, $type, $catch);
    }

    $start_time = microtime(TRUE);
    // Defined before the try block so the catch handler can always log it.
    $query = '';

    try {
      $operations = $this->extractOperations($request);
      $query = $operations[0]['query'];

      // 1. Check IP restrictions.
      $ip_check = $this->ipRestrictor->isAllowed();
      if (!$ip_check['allowed']) {
        $this->securityLogger->logBlocked($ip_check['reason'], ['query' => $query]);
        return $this->createErrorResponse(
          $this->errorHandler->getSecurityErrorCode('ip_blocked'),
          $ip_check['reason'],
          403
        );
      }

      // 2. Check circuit breaker (DoS protection).
      $circuit = $this->dosProtector->checkCircuitBreaker();
      if ($circuit['open']) {
        return $this->createErrorResponse(
          'SERVICE_UNAVAILABLE',
          'Service temporarily unavailable',
          503,
          // Clamp so a reset time already in the past never yields a
          // negative retry delay.
          ['retry_after' => max(0, $circuit['reset_time'] - time())]
        );
      }

      // 3. API key authentication is handled by
      // ApiKeyAuthenticationSubscriber (runs after Drupal's session
      // authentication at priority 250).
      // 4. Rate limiting.
      // NOTE(review): this middleware runs before Drupal's authentication
      // subscriber, so currentUser() is presumably still anonymous here and
      // $uid may always be 0. Confirm that per-user limits actually apply.
      // NOTE(review): a batched request is counted once, however many
      // operations it carries.
      $uid = \Drupal::currentUser()->id();
      $ip = $request->getClientIp();
      $rate_limit = $this->rateLimiter->checkLimit($uid, $ip);

      if (!$rate_limit['allowed']) {
        $this->securityLogger->logBlocked('Rate limit exceeded');
        return $this->createErrorResponse(
          $this->errorHandler->getSecurityErrorCode('rate_limit'),
          'Rate limit exceeded',
          429,
          ['retry_after' => $rate_limit['retry_after']]
        );
      }

      // 5-11. Per-operation checks. Every operation must pass; otherwise
      // the operations of a batched request would skip inspection.
      $analyzed = [];
      foreach ($operations as $operation) {
        $result = $this->checkOperation($operation, $ip);
        if ($result['response'] !== NULL) {
          return $result['response'];
        }
        $analyzed[] = $result;
      }

      // All checks passed - claim a connection slot. Keep this as the last
      // statement of the try block: nothing may throw between claiming the
      // slot and the try/finally below that releases it.
      $this->dosProtector->incrementConnections();
    }
    catch (\Throwable $e) {
      // \Throwable rather than \Exception so that TypeErrors caused by
      // malformed input also fail closed with a JSON error and are logged.
      $this->securityLogger->logError($e->getMessage(), [
        'query' => $query,
        'trace' => $e->getTraceAsString(),
      ]);

      return $this->createErrorResponse(
        'INTERNAL_ERROR',
        'An error occurred processing your request',
        500
      );
    }

    // Process the request. The connection slot is released even if the
    // wrapped kernel throws; otherwise the counter would leak.
    try {
      $response = $this->httpKernel->handle($request, $type, $catch);
    }
    finally {
      $this->dosProtector->decrementConnections();
    }

    // Log each successful operation. Operations of a batch share the
    // execution time measured for the whole request.
    $execution_time = microtime(TRUE) - $start_time;
    foreach ($analyzed as $result) {
      $this->securityLogger->logQuery(
        $result['query'],
        $result['variables'],
        $execution_time,
        $result['complexity']
      );
    }

    // Add rate limit headers only when the limiter reported the values;
    // invented defaults would advertise limits that may not exist.
    $rate_limit_headers = [
      'limit' => 'X-RateLimit-Limit',
      'remaining' => 'X-RateLimit-Remaining',
    ];
    foreach ($rate_limit_headers as $key => $header) {
      if (isset($rate_limit[$key])) {
        $response->headers->set($header, (string) $rate_limit[$key]);
      }
    }

    return $response;
  }

  /**
   * Runs the per-operation security checks (steps 5-11) on one operation.
   *
   * @param array $operation
   *   A normalized operation as returned by normalizeOperation().
   * @param string|null $ip
   *   The client IP, used to auto-block on a detected DoS threat.
   *
   * @return array
   *   An array whose 'response' key holds the error response when the
   *   operation is blocked. When the operation passes, 'response' is NULL and
   *   the array also holds:
   *   - query: the effective query (the persisted query text when resolved).
   *   - variables: the sanitized variables.
   *   - complexity: the score reported by the complexity analyzer.
   */
  protected function checkOperation(array $operation, ?string $ip): array {
    $query = $operation['query'];
    $variables = $operation['variables'];

    // 5. Validate persisted queries.
    $persisted_check = $this->persistedQueryManager->validateQuery($query, $operation['queryId']);
    if (!$persisted_check['allowed']) {
      $this->securityLogger->logBlocked($persisted_check['error']);
      return [
        'response' => $this->createErrorResponse(
          $this->errorHandler->getSecurityErrorCode('persisted_query'),
          $persisted_check['error'],
          400
        ),
      ];
    }

    // Use persisted query if found.
    if (isset($persisted_check['query'])) {
      $query = $persisted_check['query'];
    }

    // 6. Validate request size and structure.
    $validation = $this->requestValidator->validate($query, $variables);
    if (!$validation['valid']) {
      $this->securityLogger->logBlocked($validation['error']);
      return [
        'response' => $this->createErrorResponse(
          $this->errorHandler->getSecurityErrorCode('size'),
          $validation['error'],
          400
        ),
      ];
    }

    // 7. Sanitize inputs.
    // NOTE(review): the sanitized variables feed the analysis and logging
    // below but are never written back to the request, so the GraphQL
    // executor still receives the raw variables. Confirm this is intended.
    $variables = $this->inputSanitizer->sanitizeVariables($variables);

    // 8. Check introspection.
    $introspection_check = $this->introspectionController->validate($query);
    if (!$introspection_check['allowed']) {
      $this->securityLogger->logBlocked($introspection_check['error']);
      return [
        'response' => $this->createErrorResponse(
          $this->errorHandler->getSecurityErrorCode('introspection'),
          $introspection_check['error'],
          403
        ),
      ];
    }

    // 9. Analyze query complexity.
    $complexity = $this->complexityAnalyzer->analyze($query, $variables);
    if (!$complexity['allowed']) {
      $this->securityLogger->logBlocked('Query complexity exceeded', [
        'query' => $query,
        'complexity_score' => $complexity['complexity'],
      ]);
      return [
        'response' => $this->createErrorResponse(
          $this->errorHandler->getSecurityErrorCode('complexity'),
          sprintf('Query exceeds complexity limit (score: %d, max: %d)',
            $complexity['complexity'],
            $complexity['max_complexity']
          ),
          400
        ),
      ];
    }

    // 10. Check for DoS patterns.
    $dos_check = $this->dosProtector->checkThreat($query);
    if ($dos_check['threat_detected']) {
      $this->securityLogger->logBlocked($dos_check['reason']);
      $this->dosProtector->autoBlock($ip, $dos_check['reason']);
      return [
        'response' => $this->createErrorResponse(
          'DOS_THREAT_DETECTED',
          'Request blocked due to suspicious patterns',
          403
        ),
      ];
    }

    // 11. Advanced query analysis.
    $analysis = $this->queryAnalyzer->analyze($query);
    if (!$analysis['safe']) {
      $this->securityLogger->logBlocked('Suspicious query patterns', [
        'query' => $query,
        'issues' => $analysis['issues'],
      ]);
      return [
        'response' => $this->createErrorResponse(
          'QUERY_BLOCKED',
          'Query contains suspicious patterns',
          400
        ),
      ];
    }

    return [
      'response' => NULL,
      'query' => $query,
      'variables' => $variables,
      'complexity' => $complexity['complexity'],
    ];
  }

  /**
   * Checks if request is a GraphQL request.
   *
   * Paths are compared after normalization (lowercased, repeated slashes
   * collapsed, trailing slash dropped), so variants such as "/GraphQL" or
   * "//graphql/" still match. Drupal's router appears to match such variants
   * to the same route; verify this for the core version in use.
   *
   * NOTE(review): only the raw request path is inspected. Path aliases or
   * language prefixes that resolve to a GraphQL route are not detected, and
   * endpoints configured under /admin/ are never protected.
   *
   * @param \Symfony\Component\HttpFoundation\Request $request
   *   The request.
   *
   * @return bool
   *   TRUE if GraphQL request.
   */
  protected function isGraphqlRequest(Request $request) {
    $path = $this->normalizePath($request->getPathInfo());

    // Exclude admin routes.
    if (str_starts_with($path, '/admin/')) {
      return FALSE;
    }

    // Check if current path matches any configured GraphQL endpoint.
    foreach ($this->getGraphqlEndpoints() as $endpoint) {
      $endpoint = $this->normalizePath($endpoint);
      if ($path === $endpoint || str_starts_with($path, $endpoint . '/')) {
        return TRUE;
      }
    }

    return FALSE;
  }

  /**
   * Normalizes a path for endpoint comparison.
   *
   * @param string $path
   *   A path beginning with "/".
   *
   * @return string
   *   The lowercased path with repeated slashes collapsed and any trailing
   *   slash removed (the root path "/" is kept as is).
   */
  protected function normalizePath(string $path): string {
    $path = mb_strtolower((string) preg_replace('@/+@', '/', $path));
    return $path === '/' ? $path : rtrim($path, '/');
  }

  /**
   * Gets all configured GraphQL endpoint paths.
   *
   * Endpoints come from the "endpoint" property of every graphql_server
   * entity. The result is computed once per instance and cached.
   *
   * @return array
   *   Array of GraphQL endpoint paths, each starting with "/". Falls back to
   *   ['/graphql'] when no servers exist or they cannot be loaded.
   */
  protected function getGraphqlEndpoints() {
    // Return cached endpoints if available.
    if (isset($this->graphqlEndpoints)) {
      return $this->graphqlEndpoints;
    }

    $this->graphqlEndpoints = [];

    try {
      // Load all GraphQL server entities.
      $storage = $this->entityTypeManager->getStorage('graphql_server');
      $servers = $storage->loadMultiple();

      /** @var \Drupal\graphql\Entity\ServerInterface $server */
      foreach ($servers as $server) {
        // The endpoint is a public property on the server entity.
        $endpoint = $server->endpoint;
        if (!empty($endpoint)) {
          // Ensure endpoint starts with /.
          if (!str_starts_with($endpoint, '/')) {
            $endpoint = '/' . $endpoint;
          }
          $this->graphqlEndpoints[] = $endpoint;
        }
      }
    }
    catch (\Exception $e) {
      // If GraphQL module is not installed or servers can't be loaded,
      // fall back to the common GraphQL endpoint path.
      $this->graphqlEndpoints = ['/graphql'];
    }

    // If no endpoints found, use default.
    if (empty($this->graphqlEndpoints)) {
      $this->graphqlEndpoints = ['/graphql'];
    }

    return $this->graphqlEndpoints;
  }

  /**
   * Extracts every GraphQL operation carried by the request.
   *
   * A payload with an entry at key 0 is treated as a batch with one operation
   * per entry. This mirrors the isset($params[0]) test used by graphql-php's
   * request helper (verify against the installed version). A stricter "is
   * list" test would let a reordered batch such as {"1": {..}, "0": {..}}
   * pass as a single empty operation.
   *
   * @param \Symfony\Component\HttpFoundation\Request $request
   *   The request.
   *
   * @return array
   *   A non-empty list of normalized operations (see normalizeOperation()).
   */
  protected function extractOperations(Request $request): array {
    $params = $this->extractRequestParams($request);

    if (isset($params[0])) {
      return array_values(array_map(
        fn ($entry): array => $this->normalizeOperation($entry),
        $params
      ));
    }

    return [$this->normalizeOperation($params)];
  }

  /**
   * Reads the raw GraphQL parameters from the request.
   *
   * A non-empty JSON body takes precedence. Otherwise, URL query parameters
   * merged with form-encoded body parameters are used (body values win).
   * For multipart uploads, a JSON-encoded "operations" field (GraphQL
   * multipart request convention) is decoded and returned instead.
   *
   * @param \Symfony\Component\HttpFoundation\Request $request
   *   The request.
   *
   * @return array
   *   The raw, un-normalized parameters; possibly a list for batches.
   */
  protected function extractRequestParams(Request $request): array {
    $decoded = json_decode($request->getContent(), TRUE);
    if (is_array($decoded) && $decoded !== []) {
      return $decoded;
    }

    // all() is used instead of get() so that array-valued parameters do not
    // throw; they are normalized later.
    $params = array_merge($request->query->all(), $request->request->all());

    if (isset($params['operations']) && is_string($params['operations'])) {
      $operations = json_decode($params['operations'], TRUE);
      if (is_array($operations)) {
        return $operations;
      }
    }

    return $params;
  }

  /**
   * Normalizes one raw operation into a fixed shape.
   *
   * @param mixed $raw
   *   The raw operation. Anything other than an array is treated as an empty
   *   operation.
   *
   * @return array
   *   An array with keys:
   *   - query: the query string, or '' when missing or not a string.
   *   - variables: the variables array. A JSON-encoded string is decoded, and
   *     anything that is not an array becomes [].
   *   - operationName: the raw operation name, or NULL.
   *   - queryId: the raw persisted query ID, or NULL.
   */
  protected function normalizeOperation($raw): array {
    $raw = is_array($raw) ? $raw : [];

    $query = $raw['query'] ?? '';
    $variables = $raw['variables'] ?? [];
    // Some clients send variables as a JSON-encoded string.
    if (is_string($variables)) {
      $variables = json_decode($variables, TRUE);
    }

    return [
      'query' => is_string($query) ? $query : '',
      'variables' => is_array($variables) ? $variables : [],
      'operationName' => $raw['operationName'] ?? NULL,
      'queryId' => $raw['queryId'] ?? NULL,
    ];
  }

  /**
   * Extracts the first operation's data from the request.
   *
   * Retained for backward compatibility. handle() inspects every operation
   * via extractOperations().
   *
   * @param \Symfony\Component\HttpFoundation\Request $request
   *   The request.
   *
   * @return array
   *   Query data with keys query, variables, operationName and queryId.
   */
  protected function extractQueryData(Request $request) {
    return $this->extractOperations($request)[0];
  }

  /**
   * Creates an error response.
   *
   * @param string $code
   *   Error code.
   * @param string $message
   *   Error message.
   * @param int $status
   *   HTTP status code.
   * @param array $extra
   *   Extra data passed to the error handler. A numeric 'retry_after' entry
   *   is also emitted as a Retry-After header (seconds, clamped at 0).
   *
   * @return \Symfony\Component\HttpFoundation\JsonResponse
   *   Error response.
   */
  protected function createErrorResponse($code, $message, $status = 400, array $extra = []) {
    $error = $this->errorHandler->createError($code, $message, $extra);

    $response = new JsonResponse([
      'errors' => [$error],
    ], $status);

    // Expose the delay through the standard HTTP header as well, so generic
    // clients and proxies can honour it without parsing the body.
    if (isset($extra['retry_after']) && is_numeric($extra['retry_after'])) {
      $response->headers->set('Retry-After', (string) max(0, (int) $extra['retry_after']));
    }

    return $response;
  }

}
