<?php

namespace Drupal\ai_provider_lmstudio\Plugin\AiProvider;

use Drupal\ai\Base\OpenAiBasedProviderClientBase;
use Drupal\Core\Messenger\MessengerInterface;
use Drupal\Core\Session\AccountProxyInterface;
use Drupal\Core\StringTranslation\StringTranslationTrait;
use Drupal\Core\StringTranslation\TranslatableMarkup;
use Drupal\ai\Attribute\AiProvider;
use Drupal\ai\Traits\OperationType\ChatTrait;
use Drupal\ai_provider_lmstudio\LmStudioControlApi;
use Symfony\Component\DependencyInjection\ContainerInterface;

/**
 * Plugin implementation of the 'lmstudio' provider.
 *
 * Talks to an LM Studio server through its OpenAI-compatible API (served
 * under "<host>[:<port>]/v1") for chat and embeddings. Model discovery goes
 * through the separate LM Studio control API service.
 */
#[AiProvider(
  id: 'lmstudio',
  label: new TranslatableMarkup('LM Studio'),
)]
class LmStudioProvider extends OpenAiBasedProviderClientBase {

  use StringTranslationTrait;
  use ChatTrait;

  /**
   * The control API.
   *
   * @var \Drupal\ai_provider_lmstudio\LmStudioControlApi
   */
  protected LmStudioControlApi $controlApi;

  /**
   * The current user.
   *
   * @var \Drupal\Core\Session\AccountProxyInterface
   */
  protected AccountProxyInterface $currentUser;

  /**
   * The messenger service.
   *
   * @var \Drupal\Core\Messenger\MessengerInterface
   */
  protected MessengerInterface $messenger;

  /**
   * Creates the plugin and injects the LM Studio specific services.
   *
   * The control API is pointed at the configured base host right away, so
   * model discovery works without loading the OpenAI-compatible client first.
   */
  public static function create(ContainerInterface $container, array $configuration, $plugin_id, $plugin_definition) {
    $instance = parent::create($container, $configuration, $plugin_id, $plugin_definition);
    $instance->controlApi = $container->get('ai_provider_lmstudio.control_api');
    $instance->controlApi->setConnectData($instance->getBaseHost());
    $instance->currentUser = $container->get('current_user');
    $instance->messenger = $container->get('messenger');
    return $instance;
  }

  /**
   * {@inheritdoc}
   *
   * Lists the models reported by the LM Studio control API, keyed by id.
   * On failure, the error is logged and shown to users who have the
   * 'administer ai providers' permission, and an empty list is returned.
   * The operation type and capabilities are not used for filtering.
   */
  public function getConfiguredModels(?string $operation_type = NULL, array $capabilities = []): array {
    $this->loadClient();
    try {
      $response = $this->controlApi->getModels();
    }
    catch (\Exception $e) {
      if ($this->currentUser->hasPermission('administer ai providers')) {
        $this->messenger->addError($this->t('Failed to get models from LM Studio: @error', ['@error' => $e->getMessage()]));
      }
      $this->loggerFactory->get('ai_provider_lmstudio')->error('Failed to get models from LM Studio: @error', ['@error' => $e->getMessage()]);
      return [];
    }

    $models = [];
    $data = $response['data'] ?? [];
    if (!is_iterable($data)) {
      return $models;
    }
    foreach ($data as $model) {
      // Skip malformed entries instead of raising "undefined key" warnings.
      if (!isset($model['id'])) {
        continue;
      }
      $models[$model['id']] = $model['id'];
    }
    return $models;
  }

  /**
   * {@inheritdoc}
   *
   * The provider is usable once a host name is configured. If an operation
   * type is given, it must also be one of getSupportedOperationTypes().
   */
  public function isUsable(?string $operation_type = NULL, array $capabilities = []): bool {
    if (!$this->getBaseHost()) {
      return FALSE;
    }
    // If it's one of the operation types LM Studio supports, it's usable.
    if ($operation_type) {
      return in_array($operation_type, $this->getSupportedOperationTypes(), TRUE);
    }
    return TRUE;
  }

  /**
   * {@inheritdoc}
   */
  public function getSupportedOperationTypes(): array {
    return [
      'chat',
      'embeddings',
    ];
  }

  /**
   * {@inheritdoc}
   *
   * No per-model settings are added; the general configuration is returned
   * unchanged.
   */
  public function getModelSettings(string $model_id, array $generalConfig = []): array {
    return $generalConfig;
  }

  /**
   * Get control client.
   *
   * This is the client for controlling the LM Studio API.
   *
   * @return \Drupal\ai_provider_lmstudio\LmStudioControlApi
   *   The control client.
   */
  public function getControlClient(): LmStudioControlApi {
    return $this->controlApi;
  }

  /**
   * {@inheritdoc}
   *
   * Lazily creates the OpenAI-compatible client. When a host is configured,
   * the endpoint is "<host>[:<port>]/v1". It is derived from getBaseHost(), so
   * trailing slashes are trimmed and a port without a host is ignored.
   */
  protected function loadClient(): void {
    if (!empty($this->client)) {
      return;
    }
    $base_host = $this->getBaseHost();
    if ($base_host !== '') {
      $this->setEndpoint($base_host . '/v1');
    }
    $this->client = $this->createClient();
  }

  /**
   * Gets the base host.
   *
   * Combines the configured host name (without trailing slashes) and the
   * optional port. A port on its own is not a usable address, so an empty
   * string is returned when no host name is configured.
   *
   * @return string
   *   The base host, e.g. "http://localhost:1234", or an empty string when
   *   no host name is configured.
   */
  protected function getBaseHost(): string {
    // Cast so a missing (NULL) config value is never passed to rtrim(); that
    // is deprecated since PHP 8.1.
    $host = rtrim((string) $this->getConfig()->get('host_name'), '/');
    if ($host === '') {
      return '';
    }
    $port = $this->getConfig()->get('port');
    if ($port) {
      $host .= ':' . $port;
    }
    return $host;
  }

}
