<?php

namespace Drupal\ai_more_like_this;

use Drupal\ai\AiVdbProviderPluginManager;
use Drupal\ai\Enum\VdbSimilarityMetrics;
use Drupal\ai_vdb_provider_postgres\Exception\DatabaseConnectionException;
use Drupal\ai_vdb_provider_postgres\Exception\VectorSearchException;
use Drupal\Core\Entity\EntityTypeManagerInterface;
use Drupal\search_api\SearchApiException;
use Drupal\search_api\IndexInterface;
use PgSql\Connection;

/**
 * Fetches similar nodes' IDs from a Postgres pgvector-based VDB.
 *
 * The embeddings are assumed to live in a table (configurable) with:
 *  - integer ID column referencing node NID (default: nid)
 *  - pgvector column with the embedding (default: embedding)
 */
/**
 * Fetches similar nodes' IDs from a Postgres pgvector-based VDB.
 *
 * Embeddings are read from the table named by the Search API server's
 * "collection" database setting. That table is expected to contain:
 *  - a node ID column named after the index field whose property path is
 *    "nid",
 *  - an "embedding" pgvector column (this column name is not configurable).
 *
 * init() must be called before semanticProximityNodeIds(); it caches the
 * index, its backend settings and the database connection between calls.
 */
class MoreLikeThisService {
  /**
   * Search API Index.
   *
   * @var \Drupal\search_api\IndexInterface
   */
  protected IndexInterface $index;
  /**
   * Machine name of Node ID field in the Index.
   *
   * @var string
   */
  protected string $nidFieldName = '';
  /**
   * Search API backend config ("database_settings" of the server).
   *
   * @var array
   */
  protected array $backendConfig = [];
  /**
   * VDB similarity metric, as a pgvector operator symbol.
   *
   * @var string
   */
  protected string $metric = '';
  /**
   * The Embedding Strategy.
   *
   * @var string
   */
  protected string $embeddingStrategy = '';
  /**
   * The similarity distance threshold, expressed in the metric's own units.
   *
   * @var float
   */
  protected float $distanceThreshold;
  /**
   * VDB Connection.
   *
   * @var \PgSql\Connection
   */
  protected Connection $connection;
  /**
   * Name of the database the cached connection was opened against.
   *
   * Lets init() reconnect when a different index points at another database.
   *
   * @var string|null
   */
  protected ?string $connectionDatabase = NULL;

  /**
   * MoreLikeThisService constructor.
   */
  public function __construct(
    private readonly AiVdbProviderPluginManager $aiVdbProviderPluginManager,
    private readonly EntityTypeManagerInterface $entityTypeManager,
  ) {}

  /**
   * Initializes all member variables.
   *
   * The index and its backend settings are reloaded only when a different
   * index is requested. The database connection is (re)opened only when none
   * exists yet or the backend database name changed.
   *
   * @param array $options
   *   Options from the Argument plugin. Keys read: 'index' (Search API index
   *   ID), 'metric' (a VdbSimilarityMetrics value), 'distance_threshold'
   *   (a cosine distance) and, optionally, 'embedding_strategy'.
   *
   * @return array
   *   Human-readable error messages; empty when initialization succeeded.
   *
   * @throws \Drupal\ai_vdb_provider_postgres\Exception\DatabaseConnectionException
   * @throws \Drupal\search_api\SearchApiException|\Drupal\Component\Plugin\Exception\PluginException
   */
  public function init(array $options): array {
    $init_errors = [];
    if (empty($this->index) || $this->index->id() !== $options['index']) {
      $rag_storage = $this->entityTypeManager->getStorage('search_api_index');
      /** @var \Drupal\search_api\Entity\Index|null $index */
      $index = $rag_storage->load($options['index']);
      if (!$index) {
        // Stop here: without an index, a "no Node ID field" error would only
        // be noise.
        $init_errors[] = 'AI Search Index not found.';
        return $init_errors;
      }
      $nid_field = NULL;
      foreach ($index->getFields() as $field_machine_name => $field) {
        if ($field->getPropertyPath() === 'nid') {
          $nid_field = $field_machine_name;
          break;
        }
      }
      if (empty($nid_field)) {
        $init_errors[] = 'No Node ID field found in the Index.';
        return $init_errors;
      }
      $this->nidFieldName = $nid_field;
      $this->index = $index;
      $this->backendConfig = $this->index->getServerInstance()->getBackendConfig()['database_settings'];
    }
    if (empty($options['metric'])) {
      $init_errors[] = 'No Similarity Metric configured in Argument.';
      return $init_errors;
    }
    // Report unknown metrics instead of letting getMetricSymbol() throw a
    // ValueError from VdbSimilarityMetrics::from().
    if (VdbSimilarityMetrics::tryFrom($options['metric']) === NULL) {
      $init_errors[] = 'Unsupported Similarity Metric configured in Argument.';
      return $init_errors;
    }
    $this->distanceThreshold = $this->adjustThresholdForMetric($options);
    $this->metric = $this->getMetricSymbol($options['metric']);
    // Any value other than 'contextual_chunks' selects the one-vector-per-node
    // query in getRagResults().
    $this->embeddingStrategy = (string) ($options['embedding_strategy'] ?? '');
    // Connect to the VDB, reconnecting if the index now uses another database.
    $database_name = $this->backendConfig['database_name'] ?? NULL;
    if (empty($this->connection) || $this->connectionDatabase !== (string) $database_name) {
      $ai_vdb_provider_postgres = $this->aiVdbProviderPluginManager->createInstance('postgres');
      $this->connection = $ai_vdb_provider_postgres->getConnection($database_name);
      $this->connectionDatabase = (string) $database_name;
    }
    return $init_errors;
  }

  /**
   * Returns only the list of similar NIDs that are closer than the threshold.
   *
   * @param int|null $contextNid
   *   The node ID to find buddies for. NULL yields an empty list.
   * @param int $limit
   *   Max number of nodes. Values below 1 yield an empty list.
   *
   * @return array
   *   The node IDs of similar nodes (as returned by the database driver),
   *   ordered by increasing distance. Empty when the vector search fails.
   */
  public function semanticProximityNodeIds(?int $contextNid, int $limit): array {
    // getRagResults() requires a real node ID, and a non-positive LIMIT can
    // never produce rows, so skip the round-trip entirely.
    if ($contextNid === NULL || $limit < 1) {
      return [];
    }

    try {
      $results = $this->getRagResults($contextNid, $limit);
    }
    catch (VectorSearchException $e) {
      return [];
    }

    // Keep only rows strictly closer than the threshold. Rows with a NULL
    // distance (presumably when the context node has no embedding row) are
    // dropped by isset().
    $related_nids = [];
    foreach ($results as $rag_result) {
      if (isset($rag_result['distance']) && $rag_result['distance'] < $this->distanceThreshold) {
        $related_nids[] = $rag_result[$this->nidFieldName];
      }
    }
    return $related_nids;
  }

  /**
   * Get RAG results list.
   *
   * SQL identifiers come from the index/backend configuration and the metric
   * symbol from getMetricSymbol(); $nid and $limit are typed integers, so no
   * request-supplied text is interpolated into the query.
   *
   * @param int $nid
   *   The context Node ID.
   * @param int $limit
   *   The query LIMIT parameter.
   *
   * @return array
   *   Rows keyed by the node ID field name and 'distance', ordered by
   *   increasing distance.
   *
   * @throws \Drupal\ai_vdb_provider_postgres\Exception\VectorSearchException
   *   When the query fails.
   */
  protected function getRagResults(int $nid, int $limit) {

    $collection_name = $this->backendConfig['collection'];

    if ($this->embeddingStrategy == 'contextual_chunks') {
      // Here we calculate the distance between the AVG vector for the context
      // node and individual chunks of all other nodes, keeping only each
      // node's closest chunk (DISTINCT ON + ORDER BY distance).
      $query = "WITH scored AS (
        SELECT DISTINCT ON ({$this->nidFieldName})
          {$this->nidFieldName},
          embedding {$this->metric} (SELECT AVG(embedding) FROM {$collection_name} WHERE {$this->nidFieldName} = {$nid}) AS distance
        FROM {$collection_name}
        WHERE {$this->nidFieldName} <> {$nid}
        ORDER BY {$this->nidFieldName}, distance ASC
      )
      SELECT *
      FROM scored
      ORDER BY distance ASC
      LIMIT {$limit};";
    }
    else {
      // embedding_strategy == 'average_pool' means no SQL grouping,
      // no averaging needed. NOTE(review): the scalar subquery assumes one
      // embedding row per node; Postgres errors if it returns several.
      $query = "SELECT {$this->nidFieldName}, embedding {$this->metric} (SELECT (embedding) from {$collection_name} where {$this->nidFieldName} = {$nid}) as distance FROM {$collection_name} WHERE {$this->nidFieldName} != {$nid} ORDER BY distance LIMIT {$limit};";
    }

    $result = pg_query(connection: $this->connection, query: $query);
    if (!$result) {
      throw new VectorSearchException(message: pg_last_error(connection: $this->connection));
    }
    $rows = pg_fetch_all(result: $result);
    return is_array($rows) ? $rows : [];
  }

  /**
   * Modifies distance threshold depending on the metric.
   *
   * The configured threshold is a cosine distance (1 - cosine similarity).
   * The conversions below are exact only for unit-length embeddings —
   * assumed here; verify the embedding provider normalizes its vectors.
   *
   * @param array $options
   *   The options passed from Argument plugin ('metric' and
   *   'distance_threshold' are read).
   *
   * @return float
   *   The threshold expressed in the units of the chosen pgvector operator.
   */
  protected function adjustThresholdForMetric(array $options) {
    $d_cos = (float) $options['distance_threshold'];
    return match ($options['metric']) {
      // pgvector's <=> already returns cosine distance.
      'cosine_similarity' => $d_cos,
      // For unit vectors: ||a - b||^2 = 2 * (1 - cos) = 2 * d_cos.
      'euclidean_distance' => sqrt($d_cos * 2),
      // The metric == 'inner_product'. pgvector's <#> returns the negative
      // inner product: -(1 - d_cos) = d_cos - 1 for unit vectors.
      default => $d_cos - 1,
    };
  }

  /**
   * Gets PostgreSQL metric symbol for Metric Type.
   *
   * @param string $metric
   *   The metric type (a VdbSimilarityMetrics backing value).
   *
   * @return string
   *   The pgvector metric symbol, <-> or '<=>' or <#>.
   *
   * @throws \ValueError
   *   When $metric is not a VdbSimilarityMetrics value.
   */
  protected function getMetricSymbol(string $metric): string {
    $metric_type = VdbSimilarityMetrics::from(
      $metric
    );
    return match ($metric_type) {
      VdbSimilarityMetrics::EuclideanDistance => '<->',
      VdbSimilarityMetrics::CosineSimilarity => '<=>',
      VdbSimilarityMetrics::InnerProduct => '<#>',
    };
  }

}
