<?php

namespace Drupal\ai_provider_aws_bedrock\Plugin\AiProvider;

use Aws\Bedrock\BedrockClient;
use Aws\BedrockRuntime\BedrockRuntimeClient;
use Drupal\ai_provider_aws_bedrock\BedrockModelPluginManager;
use Drupal\ai_provider_aws_bedrock\Models\BedrockModelInterface;
use Drupal\ai_provider_aws_bedrock\Models\ModelWithInputOutputInterface;
use Drupal\Component\Plugin\Exception\PluginNotFoundException;
use Drupal\Component\Serialization\Json;
use Drupal\Component\Utility\Crypt;
use Drupal\Core\Cache\CacheBackendInterface;
use Drupal\Core\Config\ImmutableConfig;
use Drupal\Core\Plugin\ContainerFactoryPluginInterface;
use Drupal\Core\StringTranslation\TranslatableMarkup;
use Drupal\ai\Attribute\AiProvider;
use Drupal\ai\Base\AiProviderClientBase;
use Drupal\ai\OperationType\Chat\ChatInput;
use Drupal\ai\OperationType\Chat\ChatInterface;
use Drupal\ai\OperationType\Chat\ChatMessage;
use Drupal\ai\OperationType\Chat\ChatOutput;
use Drupal\ai\OperationType\Chat\Tools\ToolsFunctionOutput;
use Drupal\ai\OperationType\Embeddings\EmbeddingsInput;
use Drupal\ai\OperationType\Embeddings\EmbeddingsInterface;
use Drupal\ai\OperationType\Embeddings\EmbeddingsOutput;
use Drupal\ai\OperationType\TextToImage\TextToImageInput;
use Drupal\ai\OperationType\TextToImage\TextToImageInterface;
use Drupal\ai\OperationType\TextToImage\TextToImageOutput;
use Drupal\ai_provider_aws_bedrock\BedrockChatMessageIterator;
use Drupal\ai_provider_aws_bedrock\Decorator\BedrockJsonSerializeDecorator;
use Symfony\Component\DependencyInjection\ContainerInterface;
use Symfony\Component\Yaml\Yaml;

/**
 * Plugin implementation of the 'bedrock' provider.
 *
 * Uses two AWS SDK clients built from a Drupal AWS profile entity: a
 * BedrockClient (used here to list foundation models) and a
 * BedrockRuntimeClient (used for converse / invokeModel calls).
 */
#[AiProvider(
  id: 'bedrock',
  label: new TranslatableMarkup('AWS Bedrock'),
)]
class BedrockProvider extends AiProviderClientBase implements
  ContainerFactoryPluginInterface,
  ChatInterface,
  EmbeddingsInterface,
  TextToImageInterface {

  /**
   * The AWS Bedrock Runtime Client.
   *
   * @var \Aws\BedrockRuntime\BedrockRuntimeClient|null
   */
  protected $client;

  /**
   * The AWS Model Configuration client.
   *
   * @var \Aws\Bedrock\BedrockClient|null
   */
  protected $modelClient;

  /**
   * The AWS Bedrock factory.
   *
   * @var \Drupal\aws\AwsClientFactoryInterface
   */
  protected $clientFactory;

  /**
   * The entity type manager.
   *
   * @var \Drupal\Core\Entity\EntityTypeManagerInterface
   */
  protected $entityTypeManager;

  /**
   * The Bedrock Model plugin manager.
   */
  protected BedrockModelPluginManager $modelPluginManager;

  /**
   * The ID of the AWS profile entity used to build the clients.
   *
   * @var string
   */
  protected string $profile = '';

  /**
   * Run moderation call, before a normal call.
   *
   * NULL means "not decided yet"; getClient() then reads it from config.
   *
   * @var bool|null
   */
  protected bool|null $moderation = NULL;

  /**
   * {@inheritdoc}
   */
  public static function create(ContainerInterface $container, array $configuration, $plugin_id, $plugin_definition): static {
    $instance = parent::create($container, $configuration, $plugin_id, $plugin_definition);
    $instance->clientFactory = $container->get('aws.client_factory');
    $instance->entityTypeManager = $container->get('entity_type.manager');
    $instance->modelPluginManager = $container->get('ai_provider_aws_bedrock.models');
    return $instance;
  }

  /**
   * {@inheritdoc}
   */
  public function getConfiguredModels(?string $operation_type = NULL, $capabilities = []): array {
    // getModels() initializes the clients itself.
    return $this->getModels($operation_type, $capabilities);
  }

  /**
   * {@inheritdoc}
   */
  public function isUsable(?string $operation_type = NULL, $capabilities = []): bool {
    // If it is not configured, it is not usable.
    if (!$this->getConfig()->get('profile')) {
      return FALSE;
    }
    // If it is one of the operation types AWS Bedrock supports, it is usable.
    if ($operation_type) {
      return in_array($operation_type, $this->getSupportedOperationTypes(), TRUE);
    }
    return TRUE;
  }

  /**
   * {@inheritdoc}
   */
  public function getSupportedOperationTypes(): array {
    return [
      'chat',
      'embeddings',
      'text_to_image',
    ];
  }

  /**
   * {@inheritdoc}
   */
  public function getConfig(): ImmutableConfig {
    return $this->configFactory->get('ai_provider_aws_bedrock.settings');
  }

  /**
   * {@inheritdoc}
   */
  public function getApiDefinition(): array {
    // Load the API defaults shipped with the module.
    $definition = Yaml::parseFile($this->moduleHandler->getModule('ai_provider_aws_bedrock')->getPath() . '/definitions/api_defaults.yml');
    return $definition;
  }

  /**
   * {@inheritdoc}
   */
  public function getModelSettings(string $model_id, array $generalConfig = []): array {
    // We need to hardcode stuff here for now, until the API/SDK gives back.
    try {
      $plugin = $this->modelPluginManager->createInstance($model_id);
      \assert($plugin instanceof BedrockModelInterface);
    }
    catch (PluginNotFoundException) {
      // If we cannot find the model, return the general config.
      // A model is not required if the general config is enough
      // and no additional configuration is needed.
      return $generalConfig;
    }
    $plugin::providerConfig($generalConfig, $model_id);
    return $generalConfig;
  }

  /**
   * {@inheritdoc}
   */
  public function setAuthentication(mixed $authentication): void {
    // Set the new profile and reset the clients so they are rebuilt lazily.
    $this->profile = $authentication;
    $this->client = NULL;
    $this->modelClient = NULL;
  }

  /**
   * Enables moderation response, for all next coming responses.
   */
  public function enableModeration(): void {
    $this->moderation = TRUE;
  }

  /**
   * Disables moderation response, for all next coming responses.
   */
  public function disableModeration(): void {
    $this->moderation = FALSE;
  }

  /**
   * Gets the raw client.
   *
   * Always resets the authentication, so the returned client is built for
   * the given profile or, when none is given, for the configured default.
   *
   * @param string $profile
   *   If the profile should be hot swapped.
   *
   * @return \Aws\BedrockRuntime\BedrockRuntimeClient
   *   The AWS Bedrock client.
   */
  public function getClient(string $profile = ''): BedrockRuntimeClient {
    // If the moderation is not set, we load it from the configuration.
    if (is_null($this->moderation)) {
      $this->moderation = $this->getConfig()->get('moderation');
    }
    $this->setAuthentication($profile ?: $this->getDefaultProfile());
    $this->loadClient();
    return $this->client;
  }

  /**
   * Get the raw model client.
   *
   * Always resets the authentication, so the returned client is built for
   * the given profile or, when none is given, for the configured default.
   *
   * @param string $profile
   *   If the profile should be hot swapped.
   *
   * @return \Aws\Bedrock\BedrockClient
   *   The AWS Bedrock model client.
   */
  public function getModelClient(string $profile = ''): BedrockClient {
    $this->setAuthentication($profile ?: $this->getDefaultProfile());
    $this->loadClient();
    return $this->modelClient;
  }

  /**
   * Loads the AWS Bedrock clients with authentication if not initialized.
   */
  protected function loadClient(): void {
    if ($this->client) {
      return;
    }
    if (!$this->profile) {
      $this->setAuthentication($this->getDefaultProfile());
    }
    // Load the profile entity once and build both clients from it.
    $profile = $this->loadProfile();
    $this->modelClient = $this->clientFactory->setProfile($profile)->getClient('bedrock');
    $this->client = $this->clientFactory->setProfile($profile)->getClient('bedrockruntime');
  }

  /**
   * Get the default profile.
   *
   * @return string
   *   The configured profile ID, or an empty string when none is configured.
   */
  protected function getDefaultProfile(): string {
    // Config returns NULL when unset; the string return type would reject it.
    return (string) $this->getConfig()->get('profile');
  }

  /**
   * Load the profile.
   *
   * @return \Drupal\aws\Entity\ProfileInterface
   *   The AWS profile entity.
   */
  protected function loadProfile() {
    /** @var \Drupal\aws\Entity\ProfileInterface */
    return $this->entityTypeManager->getStorage('aws_profile')->load($this->profile);
  }

  /**
   * {@inheritdoc}
   */
  public function chat(array|string|ChatInput $input, string $model_id, array $tags = []): ChatOutput {
    $this->loadClient();
    // Normalize the input to the Converse "messages" structure. Arrays are
    // assumed to already be in that shape and are passed through.
    $chat_input = $input;
    if ($input instanceof ChatInput) {
      $chat_input = [];
      /** @var \Drupal\ai\OperationType\Chat\ChatMessage $message */
      foreach ($input->getMessages() as $message) {
        $chat_input[] = $this->convertChatMessage($message);
      }
    }
    elseif (is_string($input)) {
      // A bare string is a single user message.
      $chat_input = [
        [
          'role' => 'user',
          'content' => [['text' => $input]],
        ],
      ];
    }

    // Normalize the configuration.
    $this->normalizeConfiguration('chat', $model_id);

    $payload = [
      'modelId' => $model_id,
      'messages' => $chat_input,
      'inferenceConfig' => $this->configuration,
    ];

    // If we want to add tools to the input. Checking the type (instead of
    // method_exists()) avoids treating a string input as a class name.
    if ($input instanceof ChatInput && $input->getChatTools()) {
      $payload['toolConfig']['tools'] = $this->convertChatTools($input->getChatTools()->renderToolsArray());
    }

    // Set system message.
    $system_message = trim((string) $this->chatSystemRole);
    if ($system_message !== '') {
      $payload['system'] = [['text' => $system_message]];
    }

    if ($this->streamed) {
      $response = $this->client->converseStream($payload);
      $message = new BedrockChatMessageIterator($response->get('stream'));
    }
    else {
      $response = $this->client->converse($payload);
      $output_message = $response['output']['message'];

      // Text is not always the first content block (e.g. with tool use or
      // reasoning output), so take the first block that carries text.
      $text = '';
      foreach ($output_message['content'] ?? [] as $block) {
        if (isset($block['text'])) {
          $text = $block['text'];
          break;
        }
      }
      $message = new ChatMessage($output_message['role'], $text);

      // Tool usage; tools can only be resolved when they came from a ChatInput.
      if (isset($response['stopReason']) && $response['stopReason'] === 'tool_use' && $input instanceof ChatInput && $input->getChatTools()) {
        $tools = [];
        foreach ($output_message['content'] ?? [] as $block) {
          if (isset($block['toolUse'])) {
            $tools[] = new ToolsFunctionOutput($input->getChatTools()->getFunctionByName($block['toolUse']['name']), $block['toolUse']['toolUseId'], $block['toolUse']['input']);
          }
        }
        $message->setTools($tools);
      }
    }

    $loggable_raw_response = new BedrockJsonSerializeDecorator($response);

    return new ChatOutput($message, $loggable_raw_response, $response['usage']);
  }

  /**
   * Converts a chat message into a Converse API message array.
   *
   * Does not modify the given message; the role rewrite for tool results is
   * applied only to the returned array.
   *
   * @param \Drupal\ai\OperationType\Chat\ChatMessage $message
   *   The message to convert.
   *
   * @return array
   *   An array with 'role' and 'content' keys.
   */
  protected function convertChatMessage(ChatMessage $message): array {
    $role = $message->getRole();
    $text = trim((string) $message->getText());
    $content = [];

    // Add text only if it's non-empty.
    if ($text !== '') {
      $content[] = ['text' => $text];
    }

    foreach ($message->getImages() as $image) {
      $content[] = [
        'image' => [
          'format' => $image->getFileType() === 'jpg' ? 'jpeg' : $image->getFileType(),
          'source' => [
            'bytes' => $image->getBinary(),
          ],
        ],
      ];
    }

    // Replay the tools that were used.
    if ($message->getTools()) {
      foreach ($message->getRenderedTools() as $tool_use) {
        $content[] = [
          'toolUse' => [
            'toolUseId' => $tool_use['id'],
            'name' => $tool_use['function']['name'],
            // AWS wants the structured object, not the string.
            'input' => Json::decode($tool_use['function']['arguments']),
          ],
        ];
      }
    }

    // Add tool result content only for user role (Claude restriction).
    if (($role === 'tool' || $message->getToolsId()) && $role !== 'assistant') {
      $role = 'user';
      // Needs to set the content according to the AWS Spec.
      $content = [
        [
          'toolResult' => [
            'toolUseId' => $message->getToolsId(),
            'content' => [
              [
                // We need to set the text to the tool result, if empty.
                'text' => $text !== '' ? $text : 'Tool Result',
              ],
            ],
          ],
        ],
      ];
    }

    return [
      'role' => $role,
      'content' => $content,
    ];
  }

  /**
   * Converts rendered function tools into Converse "toolSpec" entries.
   *
   * @param array $tools
   *   The result of renderToolsArray(); each item has a 'function' key.
   *
   * @return array
   *   A list of ['toolSpec' => ...] arrays.
   */
  protected function convertChatTools(array $tools): array {
    $aws_tools = [];
    foreach ($tools as $tool) {
      $aws_tool = $tool['function'];
      // AWS expects the JSON schema under inputSchema.json; default to an
      // empty object schema when the tool declares no parameters.
      $aws_tool['inputSchema']['json'] = $tool['function']['parameters'] ?? ['type' => 'object'];
      unset($aws_tool['parameters']);
      $aws_tools[] = [
        'toolSpec' => $aws_tool,
      ];
    }
    return $aws_tools;
  }

  /**
   * {@inheritdoc}
   */
  public function textToImage(string|TextToImageInput $input, string $model_id, array $tags = []): TextToImageOutput {
    $this->loadClient();
    // Normalize the input if needed.
    if ($input instanceof TextToImageInput) {
      $input = $input->getText();
    }

    try {
      $plugin = $this->modelPluginManager->createInstance($model_id);
      \assert($plugin instanceof ModelWithInputOutputInterface);
      $payload = $plugin::formatInput(input: $input, config: $this->configuration);
    }
    catch (PluginNotFoundException) {
      throw new \InvalidArgumentException('The model ' . $model_id . ' was not found.');
    }

    $response = $this->client->invokeModel([
      'modelId' => $model_id,
      // Throw on encoding failure instead of sending "false" as the body.
      'body' => json_encode($payload, JSON_THROW_ON_ERROR),
      'contentType' => 'application/json',
    ]);
    // The body is a stream object; read it explicitly as a string.
    $body = json_decode((string) $response['body'], TRUE);

    $images = $plugin::formatOutput($body, $this->configuration);
    return new TextToImageOutput($images, $response, []);
  }

  /**
   * {@inheritdoc}
   */
  public function embeddings(string|EmbeddingsInput $input, string $model_id, array $tags = []): EmbeddingsOutput {
    $this->loadClient();
    // Normalize the input; a plain string is a text-only prompt. Previously a
    // string was JSON-encoded as-is, which is not a valid model request body.
    if ($input instanceof EmbeddingsInput) {
      $text = $input->getPrompt();
      $image = $input->getImage();
    }
    else {
      $text = $input;
      $image = NULL;
    }

    try {
      $plugin = $this->modelPluginManager->createInstance($model_id);
      \assert($plugin instanceof ModelWithInputOutputInterface);
      $payload = $plugin::formatInput($text, $image, $this->configuration, $model_id);
    }
    catch (PluginNotFoundException) {
      throw new \InvalidArgumentException('The model ' . $model_id . ' was not found.');
    }

    $response = $this->client->invokeModel([
      'modelId' => $model_id,
      // Throw on encoding failure instead of sending "false" as the body.
      'body' => json_encode($payload, JSON_THROW_ON_ERROR),
      'contentType' => 'application/json',
    ]);
    // The body is a stream object; read it explicitly as a string.
    $body = json_decode((string) $response['body'], TRUE);
    $embeddings = $plugin instanceof ModelWithInputOutputInterface ? $plugin::formatOutput($body, $this->configuration) : [];

    return new EmbeddingsOutput($embeddings, $body, []);
  }

  /**
   * {@inheritdoc}
   */
  public function maxEmbeddingsInput($model_id = ''): int {
    return 1024;
  }

  /**
   * Obtains a list of models from AWS Bedrock and caches the result.
   *
   * This method does its best job to filter out deprecated or unused models.
   * The AWS Bedrock API endpoint does not have a way to filter those out yet.
   * If listing fails, only the manually configured models are returned and
   * the result is not cached, so a later call can retry.
   *
   * @param string|null $operation_type
   *   The operation type to filter models by. NULL or an unknown type lists
   *   text-output models and adds no manual models.
   * @param array $capabilities
   *   The capabilities to filter models by.
   *
   * @return array
   *   A filtered list of public models, keyed by model ID.
   */
  public function getModels(?string $operation_type, $capabilities = []): array {
    // Make sure the model client exists, also when called directly.
    $this->loadClient();

    // The profile is part of the key so hot-swapped profiles (possibly other
    // AWS accounts/regions) do not share a model list.
    $cache_key = 'bedrock_models_' . $this->profile . '_' . ($operation_type ?? '') . '_' . Crypt::hashBase64(Json::encode($capabilities));
    $cache_data = $this->cacheBackend->get($cache_key);
    if (!empty($cache_data)) {
      return $cache_data->data;
    }

    $output_modality = match ($operation_type) {
      'text_to_image' => 'IMAGE',
      'embeddings' => 'EMBEDDING',
      default => 'TEXT',
    };

    $models = [];
    $listed = FALSE;
    $on_demand_only = (bool) $this->getConfig()->get('on_demand');
    try {
      $list = $this->modelClient->listFoundationModels([
        'byOutputModality' => $output_modality,
      ]);

      foreach ($list['modelSummaries'] ?? [] as $model) {
        // Only active models.
        if ($model['modelLifecycle']['status'] !== 'ACTIVE') {
          continue;
        }
        // If we should only show on demand.
        if ($on_demand_only && !in_array('ON_DEMAND', $model['inferenceTypesSupported'] ?? [], TRUE)) {
          continue;
        }
        // If the capabilities are not empty, we filter by them.
        if (count($capabilities)) {
          $modelPlugin = $this->modelPluginManager->createInstanceFromModelId($model['modelId']);
          // If the model does not support the capabilities, we skip it.
          if ($modelPlugin instanceof BedrockModelInterface && $modelPlugin::providerCapabilities($capabilities, $model['modelId']) === FALSE) {
            continue;
          }
        }
        $models[$model['modelId']] = $model['modelName'] . ' (' . $model['modelId'] . ')';
      }
      $listed = TRUE;
    }
    catch (\Exception) {
      // Best effort: fall back to manual models only, and skip caching below.
    }

    // Also add possible manual models, one model ID per line.
    if (in_array($operation_type, ['chat', 'embeddings', 'text_to_image'], TRUE)) {
      $manual_models = (string) $this->getConfig()->get($operation_type . '_manual_models');
      foreach (explode("\n", $manual_models) as $line) {
        $model = trim($line);
        if ($model === '') {
          continue;
        }
        $models[$model] = $model;
      }
    }

    if (!empty($models)) {
      asort($models);
      // Never cache permanently after a failed listing.
      if ($listed) {
        $this->cacheBackend->set($cache_key, $models, CacheBackendInterface::CACHE_PERMANENT, ['aws_bedrock_models']);
      }
    }

    return $models;
  }

}
