<?php

namespace Drupal\govai\Plugin\AiProvider;

use Drupal\ai\Attribute\AiProvider;
use Drupal\ai\Base\OpenAiBasedProviderClientBase;
use Drupal\ai\Traits\OperationType\ChatTrait;
use Drupal\Core\Messenger\MessengerInterface;
use Drupal\Core\Session\AccountProxyInterface;
use Drupal\Core\StringTranslation\StringTranslationTrait;
use Drupal\Core\StringTranslation\TranslatableMarkup;
use Drupal\govai\GovAiControlApi;
use Symfony\Component\DependencyInjection\ContainerInterface;

/**
 * GovAI provider plugin.
 *
 * Talks to a GovAI deployment through the OpenAI-compatible client base. The
 * chat/embeddings client is pointed at "<host_name>[:<port>]/v1", while model
 * listing goes through the separate GovAI control API service using the same
 * base host (without the "/v1" suffix).
 */
#[AiProvider(
  id: 'govai',
  label: new TranslatableMarkup('GovAI'),
)]
class GovAiProvider extends OpenAiBasedProviderClientBase {

  use StringTranslationTrait;
  use ChatTrait;

  /**
   * GovAI control API client.
   */
  protected GovAiControlApi $controlApi;

  /**
   * Current user service.
   */
  protected AccountProxyInterface $currentUser;

  /**
   * Messenger service.
   */
  protected MessengerInterface $messenger;

  /**
   * {@inheritdoc}
   */
  public static function create(ContainerInterface $container, array $configuration, $plugin_id, $plugin_definition) {
    $instance = parent::create($container, $configuration, $plugin_id, $plugin_definition);
    $instance->controlApi = $container->get('govai.control_api');
    // The control API shares the provider's base host; getBaseHost() returns
    // an empty string when no host is configured.
    $instance->controlApi->setConnectData($instance->getBaseHost());
    $instance->currentUser = $container->get('current_user');
    $instance->messenger = $container->get('messenger');
    return $instance;
  }

  /**
   * {@inheritdoc}
   *
   * Models are fetched from the GovAI control API. The response is read as a
   * list under 'data' whose entries carry an 'id'; each usable id becomes
   * both the key and the label. Entries without a string/int id are skipped.
   * On failure an empty list is returned, the error is logged, and users with
   * 'administer ai providers' additionally see an on-screen error.
   */
  public function getConfiguredModels(?string $operation_type = NULL, array $capabilities = []): array {
    $this->loadClient();
    try {
      $response = $this->controlApi->getModels();
    }
    catch (\Exception $e) {
      // Only administrators get the on-screen message; everyone gets a log.
      if ($this->currentUser->hasPermission('administer ai providers')) {
        $this->messenger->addError($this->t('Unable to load GovAI models: @error', ['@error' => $e->getMessage()]));
      }
      $this->loggerFactory->get('govai')->error('Unable to load GovAI models: @error', ['@error' => $e->getMessage()]);
      return [];
    }

    $models = [];
    $data = $response['data'] ?? NULL;
    if (empty($data) || !is_iterable($data)) {
      return $models;
    }

    foreach ($data as $model) {
      $id = $model['id'] ?? NULL;
      // Skip malformed entries instead of emitting undefined-key warnings or
      // registering a null/empty model id.
      if (!is_string($id) && !is_int($id)) {
        continue;
      }
      if ($id === '') {
        continue;
      }
      $models[$id] = $id;
    }
    return $models;
  }

  /**
   * {@inheritdoc}
   *
   * The provider is usable only when a host is configured; when an operation
   * type is given it must also be one of getSupportedOperationTypes().
   */
  public function isUsable(?string $operation_type = NULL, array $capabilities = []): bool {
    if (!$this->getBaseHost()) {
      return FALSE;
    }
    if ($operation_type) {
      return in_array($operation_type, $this->getSupportedOperationTypes(), TRUE);
    }
    return TRUE;
  }

  /**
   * {@inheritdoc}
   */
  public function getSupportedOperationTypes(): array {
    return [
      'chat',
      'embeddings',
    ];
  }

  /**
   * {@inheritdoc}
   *
   * No per-model settings are added; the general configuration is returned
   * unchanged.
   */
  public function getModelSettings(string $model_id, array $generalConfig = []): array {
    return $generalConfig;
  }

  /**
   * GovAI control client accessor.
   *
   * @return \Drupal\govai\GovAiControlApi
   *   The control API client injected in create().
   */
  public function getControlClient(): GovAiControlApi {
    return $this->controlApi;
  }

  /**
   * {@inheritdoc}
   *
   * Lazily creates the OpenAI-compatible client. When a host is configured,
   * the endpoint is "<base host>/v1", built from getBaseHost() so that
   * trailing slashes on the host are stripped and a port is only appended to
   * a real host.
   */
  protected function loadClient(): void {
    if (empty($this->client)) {
      $base_host = $this->getBaseHost();
      if ($base_host !== '') {
        $this->setEndpoint($base_host . '/v1');
      }
      $this->client = $this->createClient();
    }
  }

  /**
   * Gets the configured base host.
   *
   * @return string
   *   The 'host_name' config value without trailing slashes, followed by
   *   ":<port>" when a port is configured. Returns an empty string when no
   *   host is configured, even if a port is set, since a bare ":<port>" is not
   *   a usable address.
   */
  protected function getBaseHost(): string {
    $host = rtrim((string) $this->getConfig()->get('host_name'), '/');
    if ($host === '') {
      return '';
    }
    $port = $this->getConfig()->get('port');
    if ($port) {
      $host .= ':' . $port;
    }
    return $host;
  }

}
