<?php

namespace Drupal\ai_provider_litellm\Plugin\AiProvider;

use Drupal\Component\Serialization\Json;
use Drupal\Core\Cache\CacheBackendInterface;
use Drupal\Core\Config\ImmutableConfig;
use Drupal\Core\StringTranslation\TranslatableMarkup;
use Drupal\ai\Attribute\AiProvider;
use Drupal\ai\Exception\AiQuotaException;
use Drupal\ai\Exception\AiRateLimitException;
use Drupal\ai\OperationType\Chat\ChatInput;
use Drupal\ai\OperationType\Chat\ChatMessage;
use Drupal\ai\OperationType\Chat\ChatOutput;
use Drupal\ai\OperationType\Chat\Tools\ToolsFunctionOutput;
use Drupal\ai\OperationType\Embeddings\EmbeddingsInput;
use Drupal\ai_provider_litellm\LiteLLM\LiteLlmAiClient;
use Drupal\ai_provider_openai\OpenAiChatMessageIterator;
use Drupal\ai_provider_openai\Plugin\AiProvider\OpenAiProvider;
use Symfony\Component\DependencyInjection\ContainerInterface;

/**
 * Plugin implementation of the 'LiteLLM Proxy' provider.
 *
 * Builds on the OpenAI provider: chat requests go through the inherited
 * OpenAI client, while model discovery and per-model parameter filtering use
 * a LiteLlmAiClient configured from this module's settings.
 */
#[AiProvider(
  id: 'litellm',
  label: new TranslatableMarkup('LiteLLM Proxy'),
)]
class LiteLlmAiProvider extends OpenAiProvider {

  /**
   * The LiteLLM API client.
   *
   * Only initialized by loadClient(); every method that reads it must call
   * loadClient() first.
   *
   * @var \Drupal\ai_provider_litellm\LiteLLM\LiteLlmAiClient
   */
  protected LiteLlmAiClient $liteLlmClient;

  /**
   * The AI cache backend.
   *
   * Used to memoize computed embedding vector sizes per plugin and model.
   *
   * @var \Drupal\Core\Cache\CacheBackendInterface
   */
  protected CacheBackendInterface $aiCache;

  /**
   * {@inheritdoc}
   *
   * Additionally injects the 'cache.ai' backend on top of the parent's
   * dependencies.
   */
  public static function create(ContainerInterface $container, array $configuration, $plugin_id, $plugin_definition) {
    $parent_instance = parent::create($container, $configuration, $plugin_id, $plugin_definition);
    $parent_instance->aiCache = $container->get('cache.ai');
    return $parent_instance;
  }

  /**
   * {@inheritdoc}
   *
   * Loads the parent OpenAI client, then builds a fresh LiteLlmAiClient from
   * the 'host' and 'api_key' settings on every call.
   */
  protected function loadClient(): void {
    parent::loadClient();
    $config = $this->getConfig();
    $this->liteLlmClient = new LiteLlmAiClient(
      $this->httpClient,
      $this->keyRepository,
      $config->get('host'),
      $config->get('api_key'),
    );
  }

  /**
   * {@inheritdoc}
   *
   * @return \Drupal\Core\Config\ImmutableConfig
   *   The 'ai_provider_litellm.settings' configuration object.
   */
  public function getConfig(): ImmutableConfig {
    return $this->configFactory->get('ai_provider_litellm.settings');
  }

  /**
   * {@inheritdoc}
   *
   * This provider does not ship a static API definition.
   */
  public function getApiDefinition(): array {
    return [];
  }

  /**
   * {@inheritdoc}
   *
   * Lists the LiteLLM models whose capability flags match the requested
   * operation type. Unknown operation types yield an empty list.
   *
   * @return array
   *   Model names keyed by model name.
   */
  public function getModels(string $operation_type, $capabilities): array {
    // The LiteLLM client property is only set by loadClient(); without this
    // call a first getModels() on the instance would hit an uninitialized
    // typed property.
    $this->loadClient();

    $models = [];
    foreach ($this->liteLlmClient->models() as $model) {
      $supported = match ($operation_type) {
        'text_to_image' => $model->supportsImageOutput,
        'text_to_speech' => $model->supportsAudioOutput,
        'audio_to_audio' => $model->supportsAudioInput && $model->supportsAudioOutput,
        'moderation' => $model->supportsModeration,
        'embeddings' => $model->supportsEmbeddings,
        'chat' => $model->supportsChat,
        'image_and_audio_to_video' => $model->supportsImageAndAudioToVideo,
        default => FALSE,
      };
      if ($supported) {
        $models[$model->name] = $model->name;
      }
    }
    return $models;
  }

  /**
   * {@inheritdoc}
   *
   * Only declares where the API key lives; no default models are set up.
   */
  public function getSetupData(): array {
    return [
      'key_config_name' => 'api_key',
    ];
  }

  /**
   * {@inheritdoc}
   *
   * Intentionally empty: it prevents the OpenAI parent's rate limit check.
   */
  public function postSetup(): void {
  }

  /**
   * {@inheritdoc}
   *
   * How the input is handled:
   * - A ChatInput is converted into OpenAI-style message arrays. These hold
   *   text plus base64 image parts, tool call ids and prior tool calls.
   * - Array or string input is forwarded unchanged as the 'messages' value.
   * - The plugin configuration is merged into the payload without overriding
   *   'model' or 'messages'.
   *
   * Tool calls in a non-streamed response are only mapped to
   * ToolsFunctionOutput objects when the input declared chat tools.
   *
   * @throws \Drupal\ai\Exception\AiRateLimitException
   *   When the error message contains "Request too large" or
   *   "Too Many Requests".
   * @throws \Drupal\ai\Exception\AiQuotaException
   *   When the error message contains "Budget has been exceeded!".
   */
  public function chat(array|string|ChatInput $input, string $model_id, array $tags = []): ChatOutput {
    $this->loadClient();
    $chat_input = $input instanceof ChatInput ? $this->buildLiteLlmMessages($input, $model_id) : $input;

    // The array union keeps 'model' and 'messages' even when the plugin
    // configuration contains the same keys.
    $payload = [
      'model' => $model_id,
      'messages' => $chat_input,
    ] + $this->configuration;

    // Resolve the tool definitions once. They are needed both to build the
    // request and to map tool calls in the response back to functions.
    $chat_tools = (is_object($input) && method_exists($input, 'getChatTools')) ? $input->getChatTools() : NULL;
    if ($chat_tools) {
      $payload['tools'] = $chat_tools->renderToolsArray();
      // Strict function mode is explicitly disabled for every tool.
      foreach ($payload['tools'] as $key => $tool) {
        $payload['tools'][$key]['function']['strict'] = FALSE;
      }
    }

    // Check for structured json schemas.
    if (is_object($input) && method_exists($input, 'getChatStructuredJsonSchema') && $input->getChatStructuredJsonSchema()) {
      $payload['response_format'] = [
        'type' => 'json_schema',
        'json_schema' => $input->getChatStructuredJsonSchema(),
      ];
    }

    try {
      if ($this->streamed) {
        $response = $this->client->chat()->createStreamed($payload);
        $message = new OpenAiChatMessageIterator($response);
      }
      else {
        $response = $this->client->chat()->create($payload)->toArray();
        $response_message = $response['choices'][0]['message'];
        $tools = [];
        // Tool calls can only be resolved against declared tools. Without
        // them there is no function definition to look the name up in.
        if ($chat_tools && !empty($response_message['tool_calls'])) {
          foreach ($response_message['tool_calls'] as $tool) {
            $arguments = Json::decode($tool['function']['arguments']);
            $tools[] = new ToolsFunctionOutput($chat_tools->getFunctionByName($tool['function']['name']), $tool['id'], $arguments);
          }
        }
        $message = new ChatMessage($response_message['role'], $response_message['content'] ?? "", []);
        if (!empty($tools)) {
          $message->setTools($tools);
        }
      }
    }
    catch (\Exception $e) {
      $error = $e->getMessage();
      // Oversized requests and throttling are both reported as rate limits.
      if (str_contains($error, 'Request too large') || str_contains($error, 'Too Many Requests')) {
        throw new AiRateLimitException($error);
      }
      // An exhausted budget is reported as a quota problem.
      if (str_contains($error, 'Budget has been exceeded!')) {
        throw new AiQuotaException($error);
      }
      throw $e;
    }

    return new ChatOutput($message, $response, []);
  }

  /**
   * Converts a ChatInput into OpenAI-style message arrays.
   *
   * @param \Drupal\ai\OperationType\Chat\ChatInput $input
   *   The chat input to convert.
   * @param string $model_id
   *   The model ID. When it contains "o1" or "o3" (case-insensitive), the
   *   system role text is sent as a user message instead.
   *
   * @return array
   *   A list of message arrays, each with 'role' and 'content' keys.
   *   'tool_call_id' and 'tool_calls' are added when the source message
   *   carries them.
   */
  private function buildLiteLlmMessages(ChatInput $input, string $model_id): array {
    $messages = [];
    if ($this->chatSystemRole) {
      $messages[] = [
        'role' => preg_match('/(o1|o3)/i', $model_id) ? 'user' : 'system',
        'content' => $this->chatSystemRole,
      ];
    }

    /** @var \Drupal\ai\OperationType\Chat\ChatMessage $message */
    foreach ($input->getMessages() as $message) {
      $content = [
        [
          'type' => 'text',
          'text' => $message->getText(),
        ],
      ];
      foreach ($message->getImages() as $image) {
        $content[] = [
          'type' => 'image_url',
          'image_url' => [
            'url' => $image->getAsBase64EncodedString(),
          ],
        ];
      }

      $new_message = [
        'role' => $message->getRole(),
        'content' => $content,
      ];
      // A message answering an earlier tool call references its id.
      if ($message->getToolsId()) {
        $new_message['tool_call_id'] = $message->getToolsId();
      }
      // Replay tool calls made by an earlier assistant turn.
      if ($message->getTools()) {
        $new_message['tool_calls'] = $message->getRenderedTools();
      }

      $messages[] = $new_message;
    }
    return $messages;
  }

  /**
   * {@inheritdoc}
   *
   * The size is measured by embedding a fixed sample string and counting the
   * normalized vector. It is then cached per plugin ID and model ID in the AI
   * cache backend. Returns 0 when no embeddings method exists or when counting
   * the vector throws.
   *
   * NOTE(review): the embeddings request itself sits outside the try block, so
   * request failures propagate to the caller instead of returning 0.
   */
  public function embeddingsVectorSize(string $model_id): int {
    $cid = 'embeddings_size:' . $this->getPluginId() . ':' . $model_id;
    if ($cached = $this->aiCache->get($cid)) {
      return $cached->data;
    }

    // Just until all providers have the trait.
    if (!method_exists($this, 'embeddings')) {
      return 0;
    }
    $input = new EmbeddingsInput('Hello world!');
    $embedding = $this->embeddings($input, $model_id);
    try {
      $size = count($embedding->getNormalized());
    }
    catch (\Exception $e) {
      return 0;
    }
    $this->aiCache->set($cid, $size);

    return $size;
  }

  /**
   * {@inheritdoc}
   *
   * Drops every setting whose key is not listed in the model's supported
   * OpenAI parameters. Unknown models get $generalConfig back unchanged.
   */
  public function getModelSettings(string $model_id, array $generalConfig = []): array {
    $this->loadClient();
    $model_info = $this->liteLlmClient->models()[$model_id] ?? NULL;

    if (!$model_info) {
      return $generalConfig;
    }

    // Filtering by key preserves the remaining keys, like unset() would.
    return array_filter(
      $generalConfig,
      fn ($name) => in_array($name, $model_info->supportedOpenAiParams, TRUE),
      ARRAY_FILTER_USE_KEY,
    );
  }

}
