import json
import os
import logging
import re
from typing import Dict, Any, Optional

from dotenv import load_dotenv
from langchain_community.chat_models import ChatOllama
from langchain_core.callbacks.base import BaseCallbackHandler

from data_search import retrieve_semantic_docs  # Your custom search logic

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()


class StreamingToSocketHandler(BaseCallbackHandler):
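    """Callback handler that forwards each newly generated token to a
    caller-supplied send function (e.g., a websocket emitter)."""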
    def __init__(self, send_token_fn):
        self.send_token_fn = send_token_fn

    def on_llm_new_token(self, token: str, **kwargs):
        self.send_token_fn(token)


class SimpleDeepSeekRAG:
    """Agentic RAG system for DeepSeek-R1 with fixed search + answer behavior"""

    def __init__(self, stream_func, model_name: str = "deepseek-r1:8b"):
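        # stream_func is called with each generated token; it is wired into the
        # LLM through the streaming callback handler below.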
        self.ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").strip()
        self.llm = ChatOllama(
            model=model_name,
            base_url=self.ollama_host,
            temperature=0.5,
            timeout=120,
            num_predict=1500,
            streaming=True,
            callbacks=[StreamingToSocketHandler(stream_func)]
        )
        logger.info(f"Initialized DeepSeek RAG with {model_name}")
        
        
    def query(self, question: str) -> Dict[str, Any]:
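        """Two-step flow: (1) ask the LLM for a KNN search query, then
        (2) retrieve documents and ask the LLM to answer strictly from them."""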
        try:
            CHAR_CAP = 600  # per-document content excerpt limit; easy to tweak

            search_prompt = f"""
            Based on this question, and considering that this chat is strictly about Brazilian legal decisions from 2024 concerning elections,
            what should I search for in a document database using KNN search?

            Question: {question}

            Provide only the main search terms in Portuguese, separated by spaces.
            Write the query inside <search> </search> tags.
            """

            raw_search_response = self.llm.invoke(search_prompt).content.strip()
            search_query = self._extract_search_terms(raw_search_response)

            logger.info(f"Generated search query: {search_query}")

            try:
                search_results = retrieve_semantic_docs(search_query)
                logger.info(f"Search returned: {len(search_results) if isinstance(search_results, list) else 1} results")
            except Exception as e:
                logger.error(f"Search failed: {e}")
                search_results = []

            formatted_results = "No documents found."
            source_list = []
            
            total_content = ""

            # Normalize to a list so iteration and len() below are safe even if
            # the search layer returns a single document or None.
            if not isinstance(search_results, list):
                search_results = [search_results] if search_results else []

            if search_results:

                formatted_results = "FOUND DOCUMENTS:\n"
                for i, doc in enumerate(search_results, 1):
                    if isinstance(doc, dict):
                        doc_id = doc.get('numero_unico')
                        doc_date = doc.get('data_publicacao')
                        content = doc.get('conteudo_html') or "No content"
                        if len(content) > CHAR_CAP:
                            content = content[:CHAR_CAP] + "..."
                            
                        total_content += content + '\n\n\n'
                        
                        formatted_results += (
                            f"\nDocument ID: {doc_id}\nDate: {doc_date}\nContent: {content}\n"
                        )
                        source_list.append({"id": doc_id, "date": doc_date})

            answer_prompt = f"""
            Based on the search results below, answer this question **in Brazilian Portuguese (português do Brasil)**.

            Question: {question}

            Search Results:
            {formatted_results}

            IMPORTANT RULES:
            - Only use information from the search results above
            - Report everything relevant you found in the documents, but prioritize answering the question.
            - Be specific about what the documents say about the topic
            - You must mention the document IDs in your answer
            - When reasoning about the documents, make sure to open and close the <think> tags
            - Explain clearly and directly, without additional labels or markings
            - ONLY send the final answer inside <answer> tags in this exact format:
            <answer>
                your detailed answer based on search results, written in Brazilian Portuguese
            </answer>

            """

            response = self.llm.invoke(answer_prompt).content.strip()
            # The prompt requests the final answer inside <answer> tags, so try
            # that first, then fall back to a JSON payload, then to raw output.
            answer_text = self._extract_answer(response)
            json_response = self._extract_json(response)

            return {
                "success": True,
                "response": json_response if json_response else response,
                "raw_output": response,
                "search_query": search_query,
                "search_results_count": len(search_results),
                "sources": source_list,
                "total_content": total_content
            }

        except Exception as e:
            error_msg = f"Query failed: {str(e)}"
            logger.error(error_msg)
            return {
                "success": False,
                "error": error_msg,
                "response": None
            }


    def _extract_search_terms(self, raw_response: str) -> str:
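        """Extract usable search terms from the model's raw response, preferring
        <search> tags and tolerating <think>-only replies."""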
        try:
            # The search prompt asks for the query inside <search> tags, so
            # prefer that span when the model complied.
            tag_match = re.search(r'<search>(.*?)</search>', raw_response, re.DOTALL)
            if tag_match:
                tagged_query = re.sub(r'\s+', ' ', tag_match.group(1)).strip()
                if tagged_query:
                    return tagged_query[:100]

            clean_text = re.sub(r'<think>.*?</think>', '', raw_response, flags=re.DOTALL).strip()

            # If the model answered almost entirely inside <think> tags, mine
            # the thinking block for bullet-point terms instead.
            if not clean_text or len(clean_text) < 5:
                think_match = re.search(r'<think>(.*?)</think>', raw_response, re.DOTALL)
                if think_match:
                    thinking_content = think_match.group(1)
                    lines = thinking_content.split('\n')
                    terms = [
                        line.strip()[2:].split('(')[0].strip()
                        for line in lines
                        if line.strip().startswith('- ')
                    ]
                    if terms:
                        return ' '.join(terms[:3])
                    return 'documento'  # neutral fallback

            lines = [line.strip() for line in clean_text.split('\n') if line.strip()]
            if lines:
                first_line = re.sub(r'\s+', ' ', lines[0])
                return first_line[:100]

            return clean_text or 'documento'

        except Exception as e:
            logger.error(f"Search term extraction failed: {e}")
            return 'documento'


    def _extract_json(self, text: str) -> Optional[Dict[str, Any]]:
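        """Best-effort extraction of a JSON object carrying an "answer" or
        "response" key from the model output; returns None if nothing parses."""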
        try:
            clean_text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
            patterns = [
                r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}',
                r'\{\s*"answer"[^}]+\}',
            ]
            for pattern in patterns:
                matches = re.findall(pattern, clean_text, re.DOTALL)
                for match in matches:
                    try:
                        parsed = json.loads(match.strip())
                        if isinstance(parsed, dict) and ("answer" in parsed or "response" in parsed):
                            return parsed
                    except json.JSONDecodeError:
                        continue
            return None
        except Exception as e:
            logger.error(f"JSON extraction failed: {e}")
            return None
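
    def _extract_answer(self, text: str) -> Optional[str]:
        """Hedged addition (not in the original module): the answer prompt asks
        the model to wrap its final answer in <answer> tags, so this helper
        extracts that span; query() falls back to JSON or raw output."""
        try:
            match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL)
            return match.group(1).strip() if match else None
        except Exception as e:
            logger.error(f"Answer extraction failed: {e}")
            return None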


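# Minimal usage sketch (an assumption, not part of the original module): it
# presumes a running Ollama server with the deepseek-r1:8b model pulled and
# that data_search.retrieve_semantic_docs is importable; swap the stdout
# printer for a websocket send function in production.
if __name__ == "__main__":
    def stream_to_stdout(token: str) -> None:
        # Print tokens as they stream in, without buffering.
        print(token, end="", flush=True)

    rag = SimpleDeepSeekRAG(stream_func=stream_to_stdout)
    result = rag.query("Quais foram as decisões sobre propaganda eleitoral?")
    if result["success"]:
        print("\n\nSources:", result["sources"])
    else:
        print("Error:", result["error"])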