<?php

declare(strict_types=1);

namespace Drupal\ai_guardrails\Plugin\AiGuardrail;

use Aws\BedrockRuntime\BedrockRuntimeClient;
use Aws\Credentials\CredentialProvider;
use Aws\Result;
use Drupal\ai\Attribute\AiGuardrail;
use Drupal\ai\Guardrail\AiGuardrailPluginBase;
use Drupal\ai\Guardrail\Result\GuardrailResultInterface;
use Drupal\ai\Guardrail\Result\PassResult;
use Drupal\ai\Guardrail\Result\RewriteInputResult;
use Drupal\ai\Guardrail\Result\RewriteOutputResult;
use Drupal\ai\Guardrail\Result\StopResult;
use Drupal\ai\OperationType\Chat\ChatInput;
use Drupal\ai\OperationType\Chat\ChatMessage;
use Drupal\ai\OperationType\Chat\ChatOutput;
use Drupal\Component\Plugin\ConfigurableInterface;
use Drupal\Core\Form\FormStateInterface;
use Drupal\Core\Plugin\PluginFormInterface;
use Drupal\Core\StringTranslation\StringTranslationTrait;
use Drupal\Core\StringTranslation\TranslatableMarkup;

/**
 * Plugin implementation of the AWS Bedrock guardrail.
 *
 * Sends either the last input message or the normalized output text to the
 * AWS Bedrock ApplyGuardrail API. It then maps the response to a guardrail
 * result: pass, stop, or rewrite.
 */
#[AiGuardrail(
  id: 'bedrock',
  label: new TranslatableMarkup('Aws Bedrock'),
  description: new TranslatableMarkup(
    'Calls AWS Bedrock to apply guardrails to the input or output text.'
  ),
)]
class AwsBedrock extends AiGuardrailPluginBase implements ConfigurableInterface, PluginFormInterface {

  use StringTranslationTrait;

  /**
   * Bedrock "source" value used when checking request (input) text.
   */
  public const INPUT = 'INPUT';

  /**
   * Bedrock "source" value used when checking response (output) text.
   */
  public const OUTPUT = 'OUTPUT';

  /**
   * The "action" value Bedrock returns when the guardrail took action.
   */
  public const GUARDRAIL_INTERVENED = 'GUARDRAIL_INTERVENED';

  /**
   * Constructs an AwsBedrock guardrail plugin.
   *
   * @param array $configuration
   *   The plugin configuration. Missing keys are filled in from
   *   ::defaultConfiguration().
   * @param string $plugin_id
   *   The plugin ID.
   * @param mixed $plugin_definition
   *   The plugin definition.
   */
  public function __construct(
    array $configuration,
    $plugin_id,
    $plugin_definition,
  ) {
    parent::__construct($configuration, $plugin_id, $plugin_definition);

    // Route through the setter so the default configuration is merged in.
    $this->setConfiguration($configuration);
  }

  /**
   * {@inheritdoc}
   */
  public function isAvailable(): bool {
    // ::class does not trigger autoloading, so this is safe without the SDK.
    return class_exists(BedrockRuntimeClient::class);
  }

  /**
   * {@inheritdoc}
   */
  public function processInput(ChatInput $input): GuardrailResultInterface {
    $messages = $input->getMessages();
    $last_message = end($messages);

    // end() returns FALSE for an empty message list.
    if (!$last_message instanceof ChatMessage) {
      return new PassResult('No text message found to analyze.');
    }

    $text = $last_message->getText();
    if ($text === '') {
      // Nothing to evaluate; skip the remote call.
      return new PassResult('No text message found to analyze.');
    }

    $result = $this->applyBedrockGuardrail($text, self::INPUT);
    if (!$this->hasIntervened($result)) {
      return new PassResult('Guardrail passed.');
    }

    $intervention_text = $this->getInterventionText($result);
    if (!empty($this->configuration['mask_request_content'])) {
      return new RewriteInputResult($intervention_text);
    }

    return new StopResult($intervention_text);
  }

  /**
   * {@inheritdoc}
   */
  public function processOutput(ChatOutput $output): GuardrailResultInterface {
    $text = $output->getNormalized()->getText();
    if ($text === '') {
      // Nothing to evaluate; skip the remote call.
      return new PassResult('No text message found to analyze.');
    }

    $result = $this->applyBedrockGuardrail($text, self::OUTPUT);
    if (!$this->hasIntervened($result)) {
      return new PassResult('Guardrail passed.');
    }

    return new RewriteOutputResult($this->getInterventionText($result));
  }

  /**
   * {@inheritdoc}
   */
  public function getConfiguration(): array {
    return $this->configuration;
  }

  /**
   * {@inheritdoc}
   */
  public function setConfiguration(array $configuration): void {
    // Provided values win; defaults only fill in missing keys.
    $this->configuration = $configuration + $this->defaultConfiguration();
  }

  /**
   * {@inheritdoc}
   */
  public function defaultConfiguration(): array {
    return [
      'guardrail_identifier' => '',
      'guardrail_version' => '',
      'aws_region' => '',
      'mask_request_content' => FALSE,
    ];
  }

  /**
   * {@inheritdoc}
   */
  public function buildConfigurationForm(
    array $form,
    FormStateInterface $form_state,
  ): array {
    $form['guardrail_identifier'] = [
      '#type' => 'textfield',
      '#title' => $this->t('Guardrail Identifier'),
      '#description' => $this->t('The identifier of the guardrail to apply.'),
      '#default_value' => $this->configuration['guardrail_identifier'] ?? '',
      '#required' => TRUE,
    ];

    $form['guardrail_version'] = [
      '#type' => 'number',
      '#title' => $this->t('Guardrail Version'),
      '#description' => $this->t('The version of the guardrail to apply.'),
      '#default_value' => $this->configuration['guardrail_version'] ?? '',
      '#required' => TRUE,
    ];

    $form['aws_region'] = [
      '#type' => 'textfield',
      '#title' => $this->t('AWS Region'),
      '#description' => $this->t(
        'The AWS region where the Bedrock service is hosted.'
      ),
      '#default_value' => $this->configuration['aws_region'] ?? '',
      '#required' => TRUE,
    ];

    $form['mask_request_content'] = [
      '#type' => 'checkbox',
      '#title' => $this->t('Mask Request Content'),
      '#description' => $this->t('With this option enabled, if the guardrail detects a PII or sensitive content in the input, it will mask the content instead of stopping the request.'),
      '#default_value' => $this->configuration['mask_request_content'] ?? FALSE,
    ];

    return $form;
  }

  /**
   * {@inheritdoc}
   */
  public function validateConfigurationForm(
    array &$form,
    FormStateInterface $form_state,
  ): void {}

  /**
   * {@inheritdoc}
   */
  public function submitConfigurationForm(
    array &$form,
    FormStateInterface $form_state,
  ): void {
    $values = $form_state->getValues();
    $this->setConfiguration($values);
  }

  /**
   * Checks whether Bedrock reported that the guardrail took action.
   *
   * @param \Aws\Result $result
   *   The ApplyGuardrail API result.
   *
   * @return bool
   *   TRUE when the result's "action" equals self::GUARDRAIL_INTERVENED.
   */
  private function hasIntervened(Result $result): bool {
    return $result->get('action') === self::GUARDRAIL_INTERVENED;
  }

  /**
   * Extracts the replacement/blocked-message text from a Bedrock result.
   *
   * @param \Aws\Result $result
   *   The ApplyGuardrail API result.
   *
   * @return string
   *   The text of the first "outputs" entry, or '' if there is none.
   */
  private function getInterventionText(Result $result): string {
    return (string) ($result->get('outputs')[0]['text'] ?? '');
  }

  /**
   * Calls AWS Bedrock to apply guardrails.
   *
   * @param string $text
   *   The text to analyze.
   * @param string $source
   *   The source of the text, either self::INPUT or self::OUTPUT.
   *
   * @return \Aws\Result
   *   The result from the Bedrock API call.
   *
   * @throws \Aws\Exception\AwsException
   *   When the client cannot be created or the API call fails.
   */
  private function applyBedrockGuardrail(string $text, string $source): Result {
    $client = new BedrockRuntimeClient([
      'region' => $this->configuration['aws_region'],
      'version' => '2023-09-30',
      'credentials' => CredentialProvider::defaultProvider(),
    ]);

    return $client->applyGuardrail([
      'guardrailIdentifier' => $this->configuration['guardrail_identifier'],
      // The API types guardrailVersion as a string ("1", "2", ... or "DRAFT").
      'guardrailVersion' => (string) $this->configuration['guardrail_version'],
      'source' => $source,
      'content' => [
        [
          'text' => [
            'text' => $text,
          ],
        ],
      ],
    ]);
  }

}
