<?php

namespace Drupal\ai_spam_protection\Hook;

use Drupal\ai\AiProviderPluginManager;
use Drupal\ai\OperationType\Chat\ChatInput;
use Drupal\ai\OperationType\Chat\ChatMessage;
use Drupal\ai\Plugin\ProviderProxy;
use Drupal\Component\Render\MarkupInterface;
use Drupal\Component\Serialization\Yaml;
use Drupal\Core\Cache\Cache;
use Drupal\Core\Config\ConfigFactoryInterface;
use Drupal\Core\Config\ImmutableConfig;
use Drupal\Core\DependencyInjection\DependencySerializationTrait;
use Drupal\Core\Flood\FloodInterface;
use Drupal\Core\Form\FormStateInterface;
use Drupal\Core\Hook\Attribute\Hook;
use Drupal\Core\Path\PathMatcherInterface;
use Drupal\Core\Session\AccountProxyInterface;
use Drupal\Core\StringTranslation\StringTranslationTrait;
use Drupal\Core\StringTranslation\TranslationInterface;
use Psr\Log\LoggerAwareInterface;
use Psr\Log\LoggerAwareTrait;
use Symfony\Component\DependencyInjection\Attribute\Autowire;
use Symfony\Component\HttpFoundation\RequestStack;

/**
 * Form hooks that classify form submissions as spam using an AI chat model.
 *
 * Protected forms receive a hidden "spam_confirm" checkbox whose element
 * validator sends the submitted values to the configured chat provider. A
 * response of "1" fails validation as spam; "0" lets the submission through.
 */
class FormHooks implements LoggerAwareInterface {

  use StringTranslationTrait;
  use DependencySerializationTrait;
  use LoggerAwareTrait;

  /**
   * Module settings (ai_spam_protection.settings), loaded in the constructor.
   */
  protected ImmutableConfig $config;

  /**
   * Constructs a FormHooks object.
   *
   * @param \Drupal\ai\AiProviderPluginManager $aiProviderPluginManager
   *   The AI provider plugin manager.
   * @param \Drupal\Core\StringTranslation\TranslationInterface $stringTranslation
   *   The string translation service.
   * @param \Drupal\Core\Config\ConfigFactoryInterface $configFactory
   *   The config factory.
   * @param \Symfony\Component\HttpFoundation\RequestStack $requestStack
   *   The request stack, used to read the client IP.
   * @param \Drupal\Core\Path\PathMatcherInterface $pathMatcher
   *   The path matcher, used for form ID and IP whitelist patterns.
   * @param \Drupal\Core\Session\AccountProxyInterface $currentUser
   *   The current user.
   * @param \Drupal\Core\Flood\FloodInterface $flood
   *   The flood service.
   */
  public function __construct(
    #[Autowire(service: 'ai.provider')]
    protected AiProviderPluginManager $aiProviderPluginManager,
    #[Autowire(service: 'string_translation')]
    TranslationInterface $stringTranslation,
    #[Autowire(service: 'config.factory')]
    protected ConfigFactoryInterface $configFactory,
    #[Autowire(service: 'request_stack')]
    protected RequestStack $requestStack,
    #[Autowire(service: 'path.matcher')]
    protected PathMatcherInterface $pathMatcher,
    #[Autowire(service: 'current_user')]
    protected AccountProxyInterface $currentUser,
    #[Autowire(service: 'flood')]
    protected FloodInterface $flood,
  ) {
    $this->setStringTranslation($stringTranslation);
    $this->config = $this->configFactory->get('ai_spam_protection.settings');
  }

  /**
   * Implements hook_form_alter().
   *
   * Adds cacheability metadata reflecting the inputs that decide protection
   * (module settings, user permissions, client IP) and, for protected forms,
   * attaches the hidden "spam_confirm" checkbox that carries the AI validator.
   *
   * @param array $form
   *   The form structure.
   * @param \Drupal\Core\Form\FormStateInterface $form_state
   *   The current form state.
   * @param string $form_id
   *   The form ID.
   */
  #[Hook('form_alter')]
  public function formAlter(&$form, FormStateInterface $form_state, $form_id) {
    // Most forms carry no cache tags or contexts yet, so default to empty
    // arrays: Cache::merge*() only accepts arrays.
    $form['#cache']['tags'] = Cache::mergeTags($form['#cache']['tags'] ?? [], $this->config->getCacheTags());
    $form['#cache']['contexts'] = Cache::mergeContexts($form['#cache']['contexts'] ?? [], ['user.permissions', 'ip']);

    if ($this->byPassSpamProtection()) {
      return;
    }

    if (!$this->formIsProtected($form_state, $form_id)) {
      return;
    }

    $form['spam_confirm'] = [
      '#type' => 'checkbox',
      '#title' => $this->config->get('checkbox_label'),
      // Hidden until a submission is flagged (see elementValidate()).
      '#access' => FALSE,
      '#required' => FALSE,
      '#default_value' => $form_state->getUserInput()['spam_confirm'] ?? FALSE,
      '#element_validate' => [
        [$this, 'elementValidate'],
      ],
    ];
  }

  /**
   * Element validator for "spam_confirm": classifies the submission via AI.
   *
   * Skips classification when the form already has errors, or when human
   * interaction is enabled and the user ticked "spam_confirm". Fails open
   * (logs and returns) when no provider/model or prompt can be loaded, or
   * when the provider request throws.
   *
   * @param array $element
   *   The spam_confirm form element.
   * @param \Drupal\Core\Form\FormStateInterface $form_state
   *   The current form state.
   * @param array $form
   *   The complete form. When the submission is flagged as spam and human
   *   interaction is enabled, the spam_confirm checkbox is revealed on it.
   */
  public function elementValidate(array &$element, FormStateInterface $form_state, array &$form) {
    // Only spend an AI request on submissions that are otherwise valid.
    if ($form_state->hasAnyErrors()) {
      return;
    }

    // NOTE(review): spam_confirm is read from raw user input, so a client that
    // posts spam_confirm=1 skips classification whenever human interaction is
    // enabled. Consider tracking flagged submissions server-side.
    if ($this->config->get('human_interaction') && !empty($form_state->getUserInput()['spam_confirm'])) {
      return;
    }

    $providerAndModel = $this->config->get('provider_model');
    if (!$providerAndModel) {
      $default = $this->aiProviderPluginManager->getDefaultProviderForOperationType('chat');
      if (!$default) {
        $this->logger?->warning('No custom provider/model configured, and no default "chat" model configured either.');
        return;
      }
      $providerAndModel = $default['provider_id'] . '__' . $default['model_id'];
    }

    /** @var ProviderProxy $provider */
    $provider = $this->aiProviderPluginManager->loadProviderFromSimpleOption($providerAndModel);
    $model = $this->aiProviderPluginManager->getModelNameFromSimpleOption($providerAndModel);
    if (!$provider || !$model) {
      $this->logger?->warning('The configured provider/model "@provider_model" could not be loaded. Please check your configuration.', ['@provider_model' => $providerAndModel]);
      return;
    }

    // The flood table's event name column has a 64 char limit. SHA-1 produces a
    // 40 character string. Added with the 'ai_spam_protection_' prefix, that
    // brings us to 59 characters. This safeguards us from very long form IDs.
    $hashed_form_id = 'ai_spam_protection_' . hash('sha1', (string) $form_state->getValue('form_id'));
    $flood_control = (bool) $this->config->get('flood_control');
    if ($flood_control && !$this->flood->isAllowed($hashed_form_id, $this->config->get('flood_threshold'), $this->config->get('flood_window'))) {
      $form_state->setErrorByName('', $this->t('Too many spam attempts were logged by your IP. Please try again later.'));
      // The submission is already rejected; do not spend an AI request on it.
      return;
    }

    $promptId = $this->config->get('prompt');
    $prompt = $promptId ? $this->configFactory->get('ai.ai_prompt.' . $promptId)->get('prompt') : NULL;
    if (!is_string($prompt) || $prompt === '') {
      $this->logger?->warning('The configured spam classification prompt "@prompt" could not be loaded. Please check your configuration.', ['@prompt' => (string) $promptId]);
      return;
    }

    // Drop Drupal's internal values (form_id, form_token, form_build_id, op)
    // so that e.g. the CSRF token is never sent to a third-party provider.
    $values = array_diff_key($form_state->getValues(), array_flip($form_state->getCleanValueKeys()));

    $messages = new ChatInput([
      new ChatMessage(
        'user',
        str_replace(
          '{formValues}',
          Yaml::encode($this->recursiveStringCast($values)),
          $prompt
        ),
      ),
    ]);

    try {
      $response = trim($provider->chat($messages, $model)->getNormalized()->getText());
    }
    catch (\Exception $e) {
      // Fail open, consistent with the misconfiguration paths above: a
      // provider outage must not make every protected form unusable.
      $this->logger?->error('Spam classification request failed: @message', ['@message' => $e->getMessage()]);
      return;
    }

    if ($response === '1') {
      if ($flood_control) {
        $this->flood->register($hashed_form_id, $this->config->get('flood_window'));
      }
      $form_state->setErrorByName('', $this->config->get('error_message'));

      if ($this->config->get('human_interaction')) {
        // Reveal the checkbox so a human can confirm and resubmit.
        $form['spam_confirm']['#access'] = TRUE;
        $form['spam_confirm']['#required'] = TRUE;
      }
      return;
    }

    if ($response !== '0') {
      $this->logger?->warning('Unexpected response from AI. The response should be "0" for "No spam" or "1" for "Spam", but got "@response".', ['@response' => $response]);
    }
  }

  /**
   * Prepares form values for YAML encoding.
   *
   * Markup and other string-castable objects are converted to strings. Any
   * other object or resource is dropped, because the YAML encoder cannot dump
   * arbitrary objects and would throw. Nested arrays are processed
   * recursively.
   *
   * @param array $elements
   *   The (possibly nested) form values.
   *
   * @return array
   *   The values with only scalars, NULLs, strings and arrays left.
   */
  protected function recursiveStringCast(array $elements) {
    foreach ($elements as $key => $value) {
      if (is_array($value)) {
        $elements[$key] = $this->recursiveStringCast($value);
      }
      elseif ($value instanceof MarkupInterface || $value instanceof \Stringable) {
        $elements[$key] = (string) $value;
      }
      elseif (is_object($value) || is_resource($value)) {
        unset($elements[$key]);
      }
    }

    return $elements;
  }

  /**
   * Checks whether the current request skips spam protection.
   *
   * @return bool
   *   TRUE if the user has the "bypass ai spam protection" permission or the
   *   client IP matches the configured whitelist.
   */
  public function byPassSpamProtection() {
    return $this->currentUser->hasPermission('bypass ai spam protection') || $this->clientIpIsWhiteListed();
  }

  /**
   * Checks the client IP against the configured whitelist patterns.
   *
   * @return bool
   *   TRUE if the IP matches one of the "whitelist" patterns. FALSE when the
   *   whitelist is empty or no client IP is available (e.g. no request).
   */
  public function clientIpIsWhiteListed() {
    $whitelist_patterns = $this->config->get('whitelist') ?: [];
    $client_ip = $this->requestStack->getCurrentRequest()?->getClientIp();
    if (empty($whitelist_patterns) || $client_ip === NULL) {
      return FALSE;
    }
    return $this->pathMatcher->matchPath($client_ip, implode(PHP_EOL, $whitelist_patterns));
  }

  /**
   * Decides whether a form gets spam protection.
   *
   * This module's own settings form and system forms are never protected.
   * With "protect_all" enabled every other form is protected unless it matches
   * "unprotected_ids"; otherwise only forms matching "protected_ids" are.
   *
   * @param \Drupal\Core\Form\FormStateInterface $form_state
   *   The current form state (used for the base form ID).
   * @param string $form_id
   *   The form ID.
   *
   * @return bool
   *   TRUE if the form should be protected.
   */
  public function formIsProtected(FormStateInterface $form_state, string $form_id) {
    // Never protect this module's own settings form.
    if ($form_id === 'ai_spam_protection_config' || $this->formIsSystemForm($form_id)) {
      return FALSE;
    }
    if ($this->config->get('protect_all')) {
      return !$this->formMatchPatterns($form_state, $form_id, implode(PHP_EOL, $this->config->get('unprotected_ids') ?: []));
    }
    return $this->formMatchPatterns($form_state, $form_id, implode(PHP_EOL, $this->config->get('protected_ids') ?: []));
  }

  /**
   * Checks whether a form is a system form that must never be protected.
   *
   * System, search and Views exposed forms may be submitted programmatically
   * (for example by Drush or other modules). The prefix may appear at the
   * start of the form ID or after a non-letter character.
   *
   * @param string $form_id
   *   The form ID.
   *
   * @return bool
   *   TRUE for system forms.
   */
  public function formIsSystemForm(string $form_id) {
    return preg_match('/(?:^|[^a-zA-Z])(?:system_|search_|views_exposed_form)/', $form_id) === 1;
  }

  /**
   * Matches a form's ID or base form ID against path-style patterns.
   *
   * @param \Drupal\Core\Form\FormStateInterface $form_state
   *   The current form state (used for the base form ID).
   * @param string $form_id
   *   The form ID.
   * @param string $patterns
   *   Newline-separated patterns, as accepted by the path matcher.
   *
   * @return bool
   *   TRUE if the form ID or base form ID matches any pattern.
   */
  public function formMatchPatterns(FormStateInterface $form_state, string $form_id, string $patterns) {
    if (empty($patterns)) {
      return FALSE;
    }
    if ($this->pathMatcher->matchPath($form_id, $patterns)) {
      return TRUE;
    }
    $base_form_id = $this->getBaseFormId($form_state);
    return !empty($base_form_id) && $this->pathMatcher->matchPath($base_form_id, $patterns);
  }

  /**
   * Returns the base form ID from the form's build info.
   *
   * @param \Drupal\Core\Form\FormStateInterface $form_state
   *   The current form state.
   *
   * @return string
   *   The base form ID, or an empty string if none is set.
   */
  public function getBaseFormId(FormStateInterface $form_state) {
    $build_info = $form_state->getBuildInfo();
    return !empty($build_info['base_form_id']) ? $build_info['base_form_id'] : '';
  }
}
