<?php

namespace Drupal\langfuse_ai_logging\EventSubscriber;

use Drupal\ai\Event\PreGenerateResponseEvent;
use Drupal\ai\Event\PostGenerateResponseEvent;
use Drupal\ai\Event\PostStreamingResponseEvent;
use Drupal\Component\Datetime\TimeInterface;
use Drupal\Core\Logger\LoggerChannelFactoryInterface;
use Drupal\Core\Logger\LoggerChannelInterface;
use Drupal\Core\Session\AccountProxyInterface;
use Drupal\langfuse\LangFuseClientInterface;
use Dropsolid\LangFuse\Observability\ObservationInterface;
use Symfony\Component\EventDispatcher\EventSubscriberInterface;
use Symfony\Component\HttpKernel\KernelEvents;

/**
 * Event subscriber for comprehensive AI interaction tracing with LangFuse.
 *
 * One LangFuse trace is shared by every AI operation of a Drupal request,
 * kept as the LangFuse client's "current trace":
 * 1. Pre-generation: reuse the current trace (or create it for the first AI
 *    call of the request) and start a generation for the operation.
 * 2. Post-generation: end that generation with output, usage, debug data and
 *    event metadata. The trace itself stays open for further operations.
 * 3. Streaming: annotate the trace with streaming completion metadata.
 * 4. Kernel terminate: end the shared trace, clear it from the client and
 *    sync traces to LangFuse.
 */
class LangFuseAiLoggingSubscriber implements EventSubscriberInterface {

  /**
   * Operation types treated as embedding requests.
   */
  private const EMBEDDING_OPERATION_TYPES = ['embedding', 'embeddings'];

  /**
   * Maximum number of characters kept in a metadata preview.
   */
  private const PREVIEW_LENGTH = 200;

  /**
   * The LangFuse client service.
   *
   * @var \Drupal\langfuse\LangFuseClientInterface
   */
  protected LangFuseClientInterface $langFuseClient;

  /**
   * The logger channel.
   *
   * @var \Drupal\Core\Logger\LoggerChannelInterface
   */
  protected LoggerChannelInterface $logger;

  /**
   * The current user service.
   *
   * @var \Drupal\Core\Session\AccountProxyInterface
   */
  protected AccountProxyInterface $currentUser;

  /**
   * The time service.
   *
   * @var \Drupal\Component\Datetime\TimeInterface
   */
  protected TimeInterface $time;

  /**
   * Map of AI request thread IDs to LangFuse trace IDs.
   *
   * Filled in onPreGenerateResponse() and consumed by the post events.
   *
   * @var array
   */
  protected array $threadToTraceMap = [];

  /**
   * Constructs a new LangFuseAiLoggingSubscriber.
   *
   * @param \Drupal\langfuse\LangFuseClientInterface $langfuse_client
   *   The LangFuse client service.
   * @param \Drupal\Core\Logger\LoggerChannelFactoryInterface $logger_factory
   *   The logger channel factory.
   * @param \Drupal\Core\Session\AccountProxyInterface $current_user
   *   The current user service.
   * @param \Drupal\Component\Datetime\TimeInterface $time
   *   The time service.
   */
  public function __construct(
    LangFuseClientInterface $langfuse_client,
    LoggerChannelFactoryInterface $logger_factory,
    AccountProxyInterface $current_user,
    TimeInterface $time,
  ) {
    $this->langFuseClient = $langfuse_client;
    $this->logger = $logger_factory->get('langfuse_ai_logging');
    $this->currentUser = $current_user;
    $this->time = $time;
  }

  /**
   * {@inheritdoc}
   */
  public static function getSubscribedEvents(): array {
    return [
      PreGenerateResponseEvent::EVENT_NAME => ['onPreGenerateResponse', 0],
      PostGenerateResponseEvent::EVENT_NAME => ['onPostGenerateResponse', 0],
      PostStreamingResponseEvent::EVENT_NAME => ['onPostStreamingResponse', 0],
      KernelEvents::TERMINATE => ['onKernelTerminate', 0],
    ];
  }

  /**
   * Handles the pre-generate response event.
   *
   * Reuses the request-wide LangFuse trace (creating it for the first AI
   * operation of the Drupal request) and starts a generation for this
   * operation inside it.
   *
   * @param \Drupal\ai\Event\PreGenerateResponseEvent $event
   *   The pre-generation event.
   */
  public function onPreGenerateResponse(PreGenerateResponseEvent $event): void {
    if (!$this->langFuseClient->isConfigured()) {
      $this->logger->debug('LangFuse not configured, skipping AI request tracing');
      return;
    }

    try {
      $threadId = $event->getRequestThreadId();
      $providerId = $event->getProviderId();
      $operationType = $event->getOperationType();
      $modelId = $event->getModelId();
      $configuration = $event->getConfiguration();
      // Stringify the input once: it is reused as trace input, generation
      // input and metadata preview.
      $input = $event->getInput();
      $inputString = $input ? $input->toString() : NULL;

      $trace = $this->getOrCreateRequestTrace($providerId, $operationType, $inputString);

      // Most values are stored JSON-encoded; completeGeneration() decodes
      // 'thread_id' to find this generation again after the response.
      $generationMetadata = [
        'provider_id' => json_encode($providerId),
        'operation_type' => json_encode($operationType),
        'model_id' => json_encode($modelId),
        'configuration_preview' => $this->createSafePreview(json_encode($configuration)),
        'input_preview' => $this->createSafePreview($inputString ?? ''),
        'drupal_request_time' => json_encode($this->time->getRequestTime()),
        'is_embedding_operation' => json_encode($this->isEmbeddingOperation($operationType)),
        'thread_id' => json_encode($threadId),
      ];
      if ($input) {
        $generationMetadata['input_type'] = get_class($input);
        $generationMetadata['input_length'] = mb_strlen($inputString);
      }

      // The trace name is generic for the whole request; each operation gets
      // a descriptive generation name instead.
      $generationName = match ($operationType) {
        'embedding', 'embeddings' => 'text-embedding',
        'chat_completion', 'completion' => 'chat-completion',
        'text_completion' => 'text-completion',
        default => sprintf('%s_%s', $operationType, $providerId),
      };

      $trace->createGeneration(
        $generationName,
        $modelId,
        $this->normalizeModelParameters($configuration),
        $generationMetadata,
        $inputString
      );

      // Remember which trace this AI request belongs to for the post events.
      $this->threadToTraceMap[$threadId] = $trace->getId();

      $this->logger->info('Created generation for @operation_type via @provider in trace @trace_id (thread: @thread_id)', [
        '@operation_type' => $operationType,
        '@provider' => $providerId,
        '@trace_id' => $trace->getId(),
        '@thread_id' => $threadId,
      ]);
    }
    catch (\Throwable $e) {
      // Tracing is best-effort: log and never interfere with the AI request.
      $this->logger->error('Failed to create LangFuse trace for AI request: @error', [
        '@error' => $e->getMessage(),
      ]);
    }
  }

  /**
   * Returns the client's current trace, creating it when none exists yet.
   *
   * @param mixed $providerId
   *   The AI provider ID of the first operation (presumably a string).
   * @param mixed $operationType
   *   The operation type of the first operation (presumably a string).
   * @param string|null $inputString
   *   The stringified input of the first operation, used as trace input.
   *
   * @return mixed
   *   The trace object returned by the LangFuse client.
   */
  private function getOrCreateRequestTrace($providerId, $operationType, ?string $inputString) {
    $trace = $this->langFuseClient->getCurrentTrace();
    if ($trace) {
      return $trace;
    }

    $traceName = 'drupal_ai_request';
    $trace = $this->langFuseClient->createTrace(
      $traceName,
      (string) $this->currentUser->id(),
      session_id(),
      [
        'drupal_request' => TRUE,
        'first_ai_provider' => $providerId,
        'first_operation_type' => $operationType,
        'request_timestamp' => $this->time->getCurrentTime(),
      ],
      ['ai', 'drupal', 'multi_operation'],
      NULL,
      NULL
    );
    // Seed the trace input from the first AI operation of the request.
    $trace->update(input: $inputString);

    $this->logger->info('Created new LangFuse trace @trace_id (@name) for Drupal AI request', [
      '@trace_id' => $trace->getId(),
      '@name' => $traceName,
    ]);

    return $trace;
  }

  /**
   * Converts the AI module configuration into LangFuse model parameters.
   *
   * @param mixed $configuration
   *   The configuration reported by the AI module.
   *
   * @return array
   *   An associative array of model parameters, or an empty array when the
   *   configuration is missing, empty, not an array, or a plain list.
   */
  private function normalizeModelParameters($configuration): array {
    $this->logger->debug('Model parameters from AI module: @params', [
      '@params' => json_encode($configuration),
    ]);

    if (!is_array($configuration) || $configuration === []) {
      return [];
    }

    // Parameters must be keyed by name; a plain list cannot be mapped.
    if (array_is_list($configuration)) {
      $this->logger->warning('Model parameters are indexed array, converting to empty array for LangFuse');
      return [];
    }

    return $configuration;
  }

  /**
   * Handles the post-generate response event.
   *
   * Ends the generation started for this AI request with its response data.
   * The shared trace is left open; it is ended in onKernelTerminate().
   *
   * @param \Drupal\ai\Event\PostGenerateResponseEvent $event
   *   The post-generation event.
   */
  public function onPostGenerateResponse(PostGenerateResponseEvent $event): void {
    if (!$this->langFuseClient->isConfigured()) {
      return;
    }

    $threadId = $event->getRequestThreadId();

    try {
      $trace = $this->getTraceFromEvent($threadId, 'post-generation');
      if (!$trace) {
        return;
      }

      $this->completeGeneration($trace, $event);

      $trace->updateMetadata([
        'last_response_timestamp' => $this->time->getCurrentTime(),
        'total_operations' => count($trace->getObservations()),
      ]);

      $this->logger->info('Completed generation for AI response @thread_id in trace @trace_id', [
        '@trace_id' => $trace->getId(),
        '@thread_id' => $threadId,
      ]);
    }
    catch (\Throwable $e) {
      $this->logger->error('Failed to complete LangFuse generation for AI response: @error', [
        '@error' => $e->getMessage(),
      ]);
    }
    finally {
      // The mapping is single-use; drop it even when completion failed so it
      // does not linger for the rest of the request.
      unset($this->threadToTraceMap[$threadId]);
    }
  }

  /**
   * Handles the post-streaming response event.
   *
   * Adds streaming completion metadata to the trace mapped to the thread.
   *
   * NOTE(review): onPostGenerateResponse() removes the thread mapping. If the
   * AI module dispatches the post-generate event for a streamed request
   * before this one, no trace will be found here (which may be why a missing
   * trace is only logged at debug level) — confirm the event order.
   *
   * @param \Drupal\ai\Event\PostStreamingResponseEvent $event
   *   The post-streaming event.
   */
  public function onPostStreamingResponse(PostStreamingResponseEvent $event): void {
    if (!$this->langFuseClient->isConfigured()) {
      return;
    }

    $trace = $this->getTraceFromEvent($event->getRequestThreadId(), 'streaming', 'debug');
    if (!$trace) {
      return;
    }

    try {
      $this->finalizeTrace($trace, $event, 'streaming');

      $this->logger->info('Updated LangFuse trace @trace_id with streaming completion for thread @thread_id', [
        '@trace_id' => $trace->getId(),
        '@thread_id' => $event->getRequestThreadId(),
      ]);
    }
    catch (\Throwable $e) {
      $this->logger->error('Failed to update LangFuse trace for streaming response: @error', [
        '@error' => $e->getMessage(),
      ]);
    }
  }

  /**
   * Retrieves a trace from the thread mapping.
   *
   * @param string $threadId
   *   The request thread ID.
   * @param string $context
   *   Context for logging (e.g., 'post-generation', 'streaming').
   * @param string $logLevel
   *   Logger method used when no mapping exists ('warning' or 'debug').
   *
   * @return mixed|null
   *   The trace object or NULL if not found.
   */
  private function getTraceFromEvent(string $threadId, string $context, string $logLevel = 'warning') {
    if (!isset($this->threadToTraceMap[$threadId])) {
      $this->logger->{$logLevel}('No LangFuse trace found for thread @thread_id in @context', [
        '@thread_id' => $threadId,
        '@context' => $context,
      ]);
      return NULL;
    }

    $traceId = $this->threadToTraceMap[$threadId];
    $trace = $this->langFuseClient->getTrace($traceId);

    if (!$trace) {
      $this->logger->warning('LangFuse trace @trace_id not found for thread @thread_id', [
        '@trace_id' => $traceId,
        '@thread_id' => $threadId,
      ]);
      return NULL;
    }

    return $trace;
  }

  /**
   * Ends the generation belonging to the event's thread with response data.
   *
   * Textual outputs are also written to the trace output, so the trace shows
   * the latest response of the request.
   *
   * @param mixed $trace
   *   The trace object.
   * @param \Drupal\ai\Event\PostGenerateResponseEvent $event
   *   The post-generation event.
   */
  private function completeGeneration($trace, PostGenerateResponseEvent $event): void {
    $threadId = $event->getRequestThreadId();
    $generation = $this->findGenerationForThread($trace, $threadId);

    if ($generation === NULL) {
      $this->logger->warning('No matching generation found for thread @thread_id', [
        '@thread_id' => $threadId,
      ]);
      return;
    }

    $outputData = [];
    $output = $event->getOutput();

    if ($output) {
      try {
        // Trust the AI module's normalized output.
        $normalized = $output->getNormalized();

        if (is_array($normalized)) {
          $outputData['output'] = ['raw' => $normalized];
        }
        else {
          $outputText = is_object($normalized) && method_exists($normalized, 'getText')
            ? $normalized->getText()
            : (string) $normalized;
          $outputData['output'] = $outputText;
          // Mirror the latest textual response onto the trace output.
          $trace->update(output: $outputText);
        }
      }
      catch (\Throwable $e) {
        // Fall back to the raw output if normalization or conversion fails.
        $this->logger->warning('Failed to get normalized output, using raw: @error', [
          '@error' => $e->getMessage(),
        ]);
        $outputData['output'] = $output->getRawOutput();
      }

      // Expose token usage via LangFuse's 'usage_details' field.
      $raw = $output->getRawOutput();
      if (is_array($raw) && isset($raw['usage'])) {
        if ($this->isEmbeddingOperation($event->getOperationType())) {
          // Embeddings only consume input tokens.
          $outputData['usage_details'] = ['input' => $raw['usage']['prompt_tokens'] ?? 0];
        }
        else {
          $outputData['usage_details'] = $raw['usage'];
        }
      }
    }
    else {
      $outputData['output'] = NULL;
    }

    if ($event->getDebugData()) {
      $outputData['debug_details'] = json_encode($event->getDebugData());
    }

    if ($event->getAllMetadata()) {
      $outputData['metadata'] = json_encode($event->getAllMetadata());
    }

    $generation->end($outputData);
  }

  /**
   * Finds the generation observation created for a request thread.
   *
   * @param mixed $trace
   *   The trace object.
   * @param mixed $threadId
   *   The request thread ID, compared strictly against the JSON-decoded
   *   'thread_id' metadata written in onPreGenerateResponse().
   *
   * @return \Dropsolid\LangFuse\Observability\ObservationInterface|null
   *   The first matching generation, or NULL when none matches.
   */
  private function findGenerationForThread($trace, $threadId): ?ObservationInterface {
    foreach ($trace->getObservations() as $observation) {
      if (!$observation instanceof ObservationInterface || $observation->getType() !== 'generation') {
        continue;
      }
      $metadata = $observation->getMetadata();
      if (isset($metadata['thread_id']) && json_decode($metadata['thread_id'], TRUE) === $threadId) {
        return $observation;
      }
    }
    return NULL;
  }

  /**
   * Adds completion metadata to a trace, ending it for 'response'.
   *
   * @param mixed $trace
   *   The trace object.
   * @param mixed $event
   *   The event object (PostGenerateResponseEvent or
   *   PostStreamingResponseEvent).
   * @param string $type
   *   The type of finalization ('response' or 'streaming').
   */
  private function finalizeTrace($trace, $event, string $type): void {
    $metadata = ["{$type}_timestamp" => $this->time->getCurrentTime()];

    if ($type === 'response') {
      $metadata['response_type'] = $event->getOutput() ? get_class($event->getOutput()) : NULL;
      $metadata['debug_data_json'] = json_encode($event->getDebugData());
      $metadata['event_metadata_json'] = json_encode($event->getAllMetadata());
      $trace->end();
    }
    elseif ($type === 'streaming') {
      $metadata['streaming_complete'] = TRUE;
      $metadata['final_output_preview'] = $this->createSafePreview($event->getOutput());
      $metadata['final_metadata_json'] = json_encode($event->getAllMetadata());
    }

    $trace->updateMetadata($metadata);
  }

  /**
   * Tells whether an operation type denotes an embedding request.
   *
   * @param mixed $operationType
   *   The AI operation type (presumably a string).
   *
   * @return bool
   *   TRUE for 'embedding' and 'embeddings'.
   */
  private function isEmbeddingOperation($operationType): bool {
    return in_array($operationType, self::EMBEDDING_OPERATION_TYPES, TRUE);
  }

  /**
   * Creates a short, string-only preview of arbitrary data for metadata.
   *
   * Every preview is capped at PREVIEW_LENGTH characters (plus '...').
   *
   * @param mixed $data
   *   The data to create a preview for.
   *
   * @return string|null
   *   A safe string preview or NULL for NULL input.
   */
  private function createSafePreview($data): ?string {
    if ($data === NULL) {
      return NULL;
    }

    if (is_scalar($data) || (is_object($data) && method_exists($data, '__toString'))) {
      return $this->truncatePreview((string) $data);
    }

    if (is_array($data) || is_object($data)) {
      $json = json_encode($data);
      // json_encode() returns FALSE e.g. for invalid UTF-8 or recursion.
      if ($json === FALSE) {
        return 'Complex data: ' . gettype($data);
      }
      return $this->truncatePreview($json);
    }

    return 'Complex data: ' . gettype($data);
  }

  /**
   * Cuts a string to PREVIEW_LENGTH characters, appending '...' when cut.
   *
   * @param string $text
   *   The text to shorten.
   *
   * @return string
   *   The (possibly) shortened text.
   */
  private function truncatePreview(string $text): string {
    if (mb_strlen($text) <= self::PREVIEW_LENGTH) {
      return $text;
    }
    return mb_substr($text, 0, self::PREVIEW_LENGTH) . '...';
  }

  /**
   * Handles the kernel terminate event.
   *
   * Ends the request-wide trace, clears it from the client and syncs traces
   * to LangFuse.
   *
   * @param mixed $event
   *   The kernel terminate event.
   */
  public function onKernelTerminate($event): void {
    if (!$this->langFuseClient->isConfigured()) {
      return;
    }

    $trace = $this->langFuseClient->getCurrentTrace();
    if (!$trace) {
      return;
    }

    try {
      $observations = $trace->getObservations();

      // Most recent generation that produced a non-NULL output.
      $lastOutput = NULL;
      foreach (array_reverse($observations) as $observation) {
        if (!method_exists($observation, 'getType') || $observation->getType() !== 'generation') {
          continue;
        }
        $candidate = method_exists($observation, 'getOutput') ? $observation->getOutput() : NULL;
        if ($candidate !== NULL) {
          $lastOutput = $candidate;
          break;
        }
      }

      $trace->updateMetadata([
        'drupal_request_complete' => TRUE,
        'final_timestamp' => $this->time->getCurrentTime(),
        'total_ai_operations' => count($observations),
        // Record the last generation output in the trace metadata.
        'output' => $lastOutput !== NULL ? json_encode($lastOutput) : NULL,
      ]);

      $trace->end();
      $this->langFuseClient->clearCurrentTrace();
      $this->langFuseClient->syncTraces();

      $this->logger->info('Finalized LangFuse trace @trace_id at end of Drupal request', [
        '@trace_id' => $trace->getId(),
      ]);
    }
    catch (\Throwable $e) {
      $this->logger->error('Failed to finalize LangFuse trace: @error', [
        '@error' => $e->getMessage(),
      ]);
    }
  }

}
