<?php

declare(strict_types=1);

namespace Drupal\opensearch_nlp\Cache;

use Drupal\Core\Cache\CacheBackendInterface;
use Drupal\opensearch_nlp\Service\NLPIngestionService;
use Drupal\Core\Logger\LoggerChannelFactoryInterface;

/**
 * Provides a semantic caching layer using Drupal's cache system.
 *
 * Search responses are stored under a key derived from the query text, the
 * target index name(s) and the request body, together with an embedding of
 * the query produced by the deployed model. A lookup re-embeds the incoming
 * query and returns the cached results only when the cosine similarity to the
 * stored embedding reaches the similarity threshold. Every stored key is also
 * recorded in a permanent list so entries can be enumerated, pruned and
 * cleared.
 */
class SemanticCache {

  /**
   * Default similarity threshold for cache hits (0.85 = 85% similarity).
   */
  private const float SIMILARITY_THRESHOLD = 0.85;

  /**
   * Similarity threshold for displaying similar queries (0.80 = 80% similarity).
   */
  private const float SIMILAR_QUERIES_THRESHOLD = 0.80;

  /**
   * Cache ID of the list that tracks every semantic cache entry key.
   */
  private const string KEYS_CID = 'semantic_cache_keys';

  /**
   * Lifetime of a cached response, in seconds (24 hours).
   */
  private const int RESPONSE_TTL = 86400;

  /**
   * The minimum cosine similarity for a cached response to be reused.
   *
   * @var float
   */
  protected $similarityThreshold = self::SIMILARITY_THRESHOLD;

  /**
   * The ID of the deployed embedding model, or NULL if none was found.
   *
   * @var string|null
   */
  protected $model;

  /**
   * An array to store cache keys.
   *
   * Not read or written by this class itself; the tracked keys are kept in
   * the cache backend under self::KEYS_CID.
   *
   * @var array
   */
  protected $cacheKeys = [];

  /**
   * The logger channel for this module.
   *
   * @var \Drupal\Core\Logger\LoggerChannelInterface
   */
  protected $logger;

  /**
   * Constructs the Semantic Cache service.
   *
   * Resolves the embedding model ID from the first deployed model reported
   * by the ingestion service.
   */
  public function __construct(
    /**
     * The embedding model used for semantic search.
     *
     * @var \Drupal\opensearch_nlp\Service\NLPIngestionService
     */
    protected NLPIngestionService $embeddingModel,
    /**
     * The cache backend service.
     *
     * @var \Drupal\Core\Cache\CacheBackendInterface
     */
    protected CacheBackendInterface $cache,
    LoggerChannelFactoryInterface $logger_factory,
  ) {
    $this->logger = $logger_factory->get('opensearch_nlp');
    // NOTE(review): getDeployedModels() runs every time this service is
    // constructed; consider resolving the model lazily.
    $deployed_model = $this->embeddingModel->getDeployedModels();
    $this->model = $deployed_model['hits']['hits'][0]['_id'] ?? NULL;
    if (empty($deployed_model)) {
      $this->logger->error('No deployed models found.');
    }
    elseif ($this->model === NULL) {
      // The response was non-empty but did not carry a usable model ID.
      $this->logger->error('Could not determine the ID of the deployed embedding model.');
    }
  }

  /**
   * Retrieves a cached response if a semantically similar query exists.
   *
   * The entry is located by the exact cache key for ($query, $index_name,
   * $body); its results are only returned when the stored embedding is at
   * least $similarityThreshold similar to a fresh embedding of $query. On a
   * hit the entry's hit counter is incremented and its lifetime renewed.
   *
   * @param string $query
   *   The search query text.
   * @param string|array $index_name
   *   The index name, or a list of index names, the query targets.
   * @param mixed $body
   *   The search request body; part of the cache key.
   *
   * @return mixed
   *   The cached results on a hit, or NULL on a miss or when the query could
   *   not be embedded.
   */
  public function retrieveCachedResponse(string $query, $index_name, $body = []): mixed {
    $query_embedding = $this->embedQuery($query);
    if ($query_embedding === NULL) {
      $this->logger->error('Failed to extract embedding for query: @query', ['@query' => $query]);
      return NULL;
    }

    $cache_key = $this->generateCacheKey($this->buildRequest($query, $index_name, $body));

    // Only consult entries this service registered in its key list.
    $cached_data = NULL;
    if (in_array($cache_key, $this->getTrackedKeys(), TRUE)) {
      $cache_item = $this->cache->get($cache_key);
      $cached_data = $cache_item ? $cache_item->data : NULL;
    }

    if (is_array($cached_data) && !empty($cached_data['embedding']) && is_array($cached_data['embedding'])) {
      $similarity = $this->cosineSimilarity($query_embedding, $cached_data['embedding']);
      if ($similarity >= $this->similarityThreshold) {
        // Count the hit and renew the entry's lifetime; misses leave the
        // entry untouched.
        $cached_data['hits'] = ($cached_data['hits'] ?? 0) + 1;
        $this->cache->set($cache_key, $cached_data, time() + self::RESPONSE_TTL);
        $this->logger->info('Cache hit for query: @query with similarity: @similarity and index:@index', [
          '@query' => $query,
          '@similarity' => $similarity,
          '@index' => $cached_data['index'] ?? $this->formatIndexName($index_name),
        ]);
        return $cached_data['results'] ?? NULL;
      }
    }

    $this->logger->info(
      'No cache hit for query: @query and index:@index',
      [
        '@query' => $query,
        '@index' => $this->formatIndexName($index_name),
      ]
    );
    return NULL;
  }

  /**
   * Normalizes a vector to unit length.
   *
   * @param array $vec
   *   The vector to normalize.
   *
   * @return array
   *   The normalized vector, or $vec unchanged when its norm is zero.
   */
  private function normalizeVector(array $vec): array {
    $norm = sqrt(array_sum(array_map(static fn($x): int|float => $x * $x, $vec)));
    if ($norm === 0.0) {
      // A zero vector has no direction; return it unchanged.
      return $vec;
    }
    return array_map(static fn($x): float => $x / $norm, $vec);
  }

  /**
   * Stores the search response in Drupal's cache.
   *
   * @param string $query
   *   The search query text.
   * @param mixed $response
   *   The search results to cache.
   * @param string|array $index_name
   *   The index name, or a list of index names, the query targeted.
   * @param string $search_type
   *   A label describing the search type, kept for display.
   * @param mixed $body
   *   The search request body; part of the cache key.
   */
  public function storeResponse(string $query, $response, $index_name, $search_type = 'semantic', $body = []): void {
    $query_embedding = $this->embedQuery($query);
    if ($query_embedding === NULL) {
      $this->logger->error('Failed to extract embedding for query: @query', ['@query' => $query]);
      return;
    }

    $cache_key = $this->generateCacheKey($this->buildRequest($query, $index_name, $body));

    $this->cache->set($cache_key, [
      'query' => $query,
      'embedding' => $query_embedding,
      'results' => $response,
      'index' => $this->formatIndexName($index_name),
      'hits' => 1,
      'search_type' => $search_type,
      'date' => date('Y-m-d H:i:s'),
      'body' => $body,
    ], time() + self::RESPONSE_TTL);

    // Register the key so the entry can be enumerated and cleared later.
    $keys = $this->getTrackedKeys();
    if (!in_array($cache_key, $keys, TRUE)) {
      $keys[] = $cache_key;
      $this->cache->set(self::KEYS_CID, $keys, CacheBackendInterface::CACHE_PERMANENT);
    }
  }

  /**
   * Embeds a query with the deployed model.
   *
   * @param string $query
   *   The query text.
   *
   * @return array|null
   *   The embedding vector, or NULL when no model is available or the model
   *   response carried no usable vector.
   */
  private function embedQuery(string $query): ?array {
    if ($this->model === NULL) {
      return NULL;
    }
    return $this->extractEmbedding($this->embeddingModel->predictModel($this->model, $query));
  }

  /**
   * Builds the request descriptor from which the cache key is derived.
   *
   * @param string $query
   *   The query text.
   * @param mixed $index_name
   *   A single index name or a list of index names.
   * @param mixed $body
   *   The search request body.
   *
   * @return array
   *   The request descriptor, with index names de-duplicated into a list.
   */
  private function buildRequest(string $query, mixed $index_name, mixed $body): array {
    return [
      'body' => $body,
      'index' => is_array($index_name) ? array_values(array_unique($index_name)) : [$index_name],
      'query_text' => $query,
    ];
  }

  /**
   * Formats index name(s) for logging and display.
   *
   * @param mixed $index_name
   *   A single index name or a list of index names.
   *
   * @return mixed
   *   Index names joined with ', ' for arrays; otherwise the value unchanged.
   */
  private function formatIndexName(mixed $index_name): mixed {
    return is_array($index_name) ? implode(', ', $index_name) : $index_name;
  }

  /**
   * Returns the list of tracked semantic cache keys.
   *
   * @return array
   *   The stored key list, or an empty array when absent or malformed.
   */
  private function getTrackedKeys(): array {
    $item = $this->cache->get(self::KEYS_CID);
    return ($item && is_array($item->data ?? NULL)) ? $item->data : [];
  }

  /**
   * Extracts the embedding vector from the model prediction response.
   *
   * @param mixed $response
   *   The response from the predictModel method.
   *
   * @return array|null
   *   The embedding vector, or NULL if the response is not an array or has
   *   no non-empty vector at inference_results[0].output[0].data.
   */
  private function extractEmbedding(mixed $response): ?array {
    if (!is_array($response)) {
      return NULL;
    }
    $data = $response['inference_results'][0]['output'][0]['data'] ?? NULL;
    return (is_array($data) && $data !== []) ? $data : NULL;
  }

  /**
   * Generates a unique cache key for a query and index.
   *
   * @param array $request
   *   The request data.
   *
   * @return string
   *   The generated cache key: 'semantic_cache:' followed by a SHA-256 hex
   *   digest of the normalized request.
   */
  private function generateCacheKey(array $request): string {
    // Sort keys recursively so equivalent requests hash identically.
    $normalized_request = $this->normalizeArrayRecursive($request);
    $serialized = json_encode($normalized_request, JSON_UNESCAPED_UNICODE | JSON_UNESCAPED_SLASHES);
    if ($serialized === FALSE) {
      // json_encode() fails on e.g. invalid UTF-8 or NAN/INF; serialize()
      // accepts any such value and keeps distinct requests distinct.
      $serialized = serialize($normalized_request);
    }
    return 'semantic_cache:' . hash('sha256', $serialized);
  }

  /**
   * Recursively normalizes an array by sorting its keys at every level.
   *
   * @param array $array
   *   The array to normalize.
   *
   * @return array
   *   The normalized array.
   */
  private function normalizeArrayRecursive(array $array): array {
    ksort($array);
    foreach ($array as $key => $value) {
      if (is_array($value)) {
        $array[$key] = $this->normalizeArrayRecursive($value);
      }
    }
    return $array;
  }

  /**
   * Computes cosine similarity between two vectors.
   *
   * @param array $vec1
   *   The first vector.
   * @param array $vec2
   *   The second vector.
   *
   * @return float
   *   The cosine similarity, or 0.0 when the vectors differ in dimension
   *   (e.g. embeddings from different models) or either has zero magnitude.
   */
  private function cosineSimilarity(array $vec1, array $vec2): float {
    if (count($vec1) !== count($vec2)) {
      return 0.0;
    }
    $dot_product = array_sum(array_map(static fn($a, $b): int|float => $a * $b, $vec1, $vec2));
    $magnitude1 = sqrt(array_sum(array_map(static fn($x): int|float => $x * $x, $vec1)));
    $magnitude2 = sqrt(array_sum(array_map(static fn($x): int|float => $x * $x, $vec2)));
    $denominator = $magnitude1 * $magnitude2;
    return $denominator !== 0.0 ? $dot_product / $denominator : 0.0;
  }

  /**
   * Cleans up cache keys by removing invalid ones.
   *
   * Keys whose entries are no longer retrievable (expired or deleted) are
   * dropped from the tracked list; duplicates are collapsed.
   */
  public function cleanupCacheKeys(): void {
    $cache_keys = $this->cache->get(self::KEYS_CID);
    if (!$cache_keys || !isset($cache_keys->data)) {
      return;
    }
    $keys = is_array($cache_keys->data) ? array_values(array_unique($cache_keys->data)) : [];
    $live = [];
    if ($keys !== []) {
      // getMultiple() takes its argument by reference, so pass a copy.
      $lookup = $keys;
      $live = $this->cache->getMultiple($lookup);
    }
    $valid_keys = array_values(array_filter($keys, static fn($key): bool => isset($live[$key])));
    $this->cache->set(self::KEYS_CID, $valid_keys, CacheBackendInterface::CACHE_PERMANENT);
  }

  /**
   * Clears all cached data.
   *
   * Deletes every tracked entry together with the tracked-key list itself.
   */
  public function clearAllCachedData(): void {
    $cache_keys = $this->cache->get(self::KEYS_CID);
    if (!$cache_keys || !isset($cache_keys->data)) {
      $this->logger->info('No semantic cache data found to clear.');
      return;
    }
    $keys = is_array($cache_keys->data) ? $cache_keys->data : [];
    $keys[] = self::KEYS_CID;
    $this->cache->deleteMultiple($keys);
    $this->logger->info('All semantic cache data has been cleared.');
  }

  /**
   * Retrieves all cached data for display.
   *
   * Entries are grouped by their trimmed, lower-cased query text (hits of
   * duplicates are summed), and each group lists the other groups whose
   * embeddings are at least SIMILAR_QUERIES_THRESHOLD similar.
   *
   * @return array
   *   A list of records with keys cache_key, search_type, indexes, query,
   *   hits, date and similar_queries.
   */
  public function queryCachedData(): array {
    $cache_keys = $this->getTrackedKeys();
    if ($cache_keys === []) {
      return [];
    }
    // getMultiple() takes its argument by reference.
    $cached_items = $this->cache->getMultiple($cache_keys);

    // Group entries by normalized query text.
    $records = [];
    foreach ($cached_items as $cache_key => $cache_item) {
      $data = $cache_item->data ?? NULL;
      if (is_object($data)) {
        $data = (array) $data;
      }
      if (!is_array($data) || empty($data['embedding']) || !is_array($data['embedding']) || !is_string($data['query'] ?? NULL)) {
        continue;
      }

      $normalized_query = mb_strtolower(trim($data['query']));
      if (isset($records[$normalized_query])) {
        $records[$normalized_query]['hits'] += $data['hits'] ?? 0;
        continue;
      }

      $records[$normalized_query] = [
        'cache_key' => $cache_key,
        'search_type' => $data['search_type'] ?? 'semantic',
        'indexes' => $data['index'] ?? '-',
        'query' => $data['query'],
        'hits' => $data['hits'] ?? 0,
        'embedding' => $this->normalizeVector($data['embedding']),
        'date' => $data['date'] ?? '-',
      ];
    }

    // Compute each unordered pair's similarity once; cosine similarity is
    // symmetric, so the result is recorded on both sides. Iterating i < j
    // keeps each record's similar list in the same order as $records.
    $normalized_queries = array_keys($records);
    $similar = array_fill_keys($normalized_queries, []);
    $total = count($normalized_queries);
    for ($i = 0; $i < $total; $i++) {
      $query_a = $normalized_queries[$i];
      for ($j = $i + 1; $j < $total; $j++) {
        $query_b = $normalized_queries[$j];
        $similarity = $this->cosineSimilarity(
          $records[$query_a]['embedding'],
          $records[$query_b]['embedding']
        );
        if ($similarity >= self::SIMILAR_QUERIES_THRESHOLD) {
          $similar[$query_a][] = ['query' => $records[$query_b]['query'], 'similarity' => $similarity];
          $similar[$query_b][] = ['query' => $records[$query_a]['query'], 'similarity' => $similarity];
        }
      }
    }

    $results = [];
    foreach ($records as $normalized_query => $record) {
      // Embeddings are internal; drop them from the display records.
      unset($record['embedding']);
      $record['similar_queries'] = $similar[$normalized_query];
      $results[] = $record;
    }
    return $results;
  }

}
