import argparse
import json
import logging
import os
import re
import sys
from typing import Any, Dict, List, Optional, Union

from bs4 import BeautifulSoup
from bs4.element import PreformattedString

# Console-only logging: basicConfig attaches one StreamHandler (stderr by
# default) to the root logger at INFO level. It has no effect if the root
# logger already has handlers.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class HTMLToTiptapConverter:
    """Convert HTML to Tiptap JSON format.

    Block-level HTML elements become Tiptap block nodes. Text and inline
    elements are gathered into runs that are wrapped in paragraphs, so block
    containers (doc, tableCell, listItem, blockquote) only receive block
    nodes. Formatting tags and inline ``style`` declarations become marks.
    HTML comments, doctypes and similar special strings are ignored. Runs of
    HTML whitespace are collapsed to single spaces. Empty containers are
    padded with a paragraph holding a non-breaking space.
    """

    # Elements whose entire subtree is ignored.
    _SKIP_TAGS = frozenset([
        'script', 'style', 'head', 'meta', 'link',
        'title', 'noscript', 'object', 'template',
    ])

    # Elements converted as block nodes. Every other (non-skipped) tag is
    # treated as inline: its text joins the surrounding paragraph and its tag
    # name / style may add marks. Unknown tags (e.g. ``o:p``) therefore stay
    # inside their paragraph instead of splitting it.
    _BLOCK_TAGS = frozenset([
        'html', 'body', 'div', 'section', 'article', 'main', 'header',
        'footer', 'nav', 'aside', 'address', 'center', 'hgroup', 'form',
        'fieldset', 'details', 'summary', 'figure', 'figcaption',
        'p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'blockquote', 'pre', 'hr',
        'img', 'ul', 'ol', 'li', 'dl', 'dt', 'dd',
        'table', 'caption', 'thead', 'tbody', 'tfoot', 'tr', 'td', 'th',
    ])

    # Node types that may sit inside a paragraph or heading.
    _INLINE_NODE_TYPES = frozenset(['text', 'hardBreak'])

    # Tags that ``_handle_generic`` treats as inline when called directly.
    _INLINE_FORMATTING_TAGS = frozenset([
        'span', 'strong', 'b', 'em', 'i', 'u', 's', 'code', 'sub', 'sup', 'mark',
    ])

    # HTML collapses runs of ASCII whitespace; U+00A0 (&nbsp;) is deliberately
    # not matched so non-breaking spaces survive.
    _WHITESPACE_RE = re.compile(r'[ \t\n\r\f]+')

    # font-weight values treated as bold (modelled on Tiptap's Bold extension).
    _BOLD_WEIGHT_RE = re.compile(r'^(bold(er)?|[5-9]\d{2,})$')

    # Leading non-negative integer of a colspan/rowspan value ("2px" -> 2).
    _SPAN_RE = re.compile(r'\s*([0-9]+)')

    # Upper bounds browsers apply to colspan / rowspan; also keeps hostile
    # input from producing absurdly wide or tall tables.
    _MAX_COLSPAN = 1000
    _MAX_ROWSPAN = 65534

    def __init__(self, config: Optional[Dict] = None):
        """Initialize the converter with optional configuration.

        Args:
            config: Stored on ``self.config``; no conversion step reads it yet.
        """
        self.config = config or {}

    def convert(self, html_content: str) -> Dict:
        """Convert an HTML string into a Tiptap ``doc`` node.

        Text and inline elements found directly in the body are wrapped in
        paragraphs, so the document only holds block nodes. An empty result
        gets a single non-breaking-space paragraph.

        Any exception raised while parsing or converting is logged (with a
        traceback). In that case a one-paragraph document reading
        "Error converting HTML" is returned instead.
        """
        try:
            soup = BeautifulSoup(html_content, 'html.parser')

            # Process <body> when present, otherwise the whole parsed fragment.
            body = soup.body or soup

            tiptap_doc = {
                "type": "doc",
                "content": self._convert_block_children(body),
            }

            # Safety net: Tiptap rejects empty text nodes.
            self._fix_empty_text_nodes(tiptap_doc)

            if not tiptap_doc["content"]:
                tiptap_doc["content"].append(self._create_paragraph("\u00A0"))

            return tiptap_doc

        except Exception as e:
            # Top-level boundary: keep the traceback for diagnosis and return
            # a minimal valid document instead of raising.
            logger.exception("Error parsing HTML: %s", e)
            return {
                "type": "doc",
                "content": [self._create_paragraph("Error converting HTML")]
            }

    def _fix_empty_text_nodes(self, node: Dict) -> None:
        """Recursively replace empty ``text`` values with a non-breaking space.

        Mutates ``node`` in place, descending through every ``content`` list.
        """
        if node.get("type") == "text" and node.get("text") == "":
            node["text"] = "\u00A0"  # Non-breaking space
            return

        if "content" in node and isinstance(node["content"], list):
            for child in node["content"]:
                if isinstance(child, dict):
                    self._fix_empty_text_nodes(child)

    def _convert_element(self, element) -> Optional[Union[Dict, List[Dict]]]:
        """Convert one HTML element into a Tiptap node or list of nodes.

        Returns ``None`` for ignored elements (see ``_SKIP_TAGS``). Block
        elements yield block nodes. ``a``/``br`` and inline formatting tags
        yield inline nodes, so block contexts should go through
        ``_convert_block_children`` rather than calling this directly.
        """
        tag_name = element.name.lower() if element.name else None

        if tag_name is None or tag_name in self._SKIP_TAGS:
            return None

        if tag_name == 'p':
            return self._handle_paragraph(element)
        elif tag_name in ('h1', 'h2', 'h3', 'h4', 'h5', 'h6'):
            return self._handle_heading(element)
        elif tag_name == 'table':
            return self._handle_table(element)
        elif tag_name in ('ul', 'ol'):
            return self._handle_list(element)
        elif tag_name == 'li':
            # _handle_list converts its own <li> children, so reaching this
            # branch means a stray <li> (outside ul/ol, or nested directly in
            # another <li>). A bare listItem is not valid there; give it its
            # own bulletList wrapper.
            return {
                "type": "bulletList",
                "content": [self._handle_list_item(element)]
            }
        elif tag_name == 'blockquote':
            return self._handle_blockquote(element)
        elif tag_name == 'pre':
            return self._handle_code_block(element)
        elif tag_name == 'hr':
            return self._handle_horizontal_rule()
        elif tag_name == 'img':
            return self._handle_image(element)
        elif tag_name == 'a':
            return self._handle_link(element)
        elif tag_name == 'br':
            return self._handle_break()
        elif tag_name in ('div', 'section', 'article', 'main', 'header', 'footer'):
            return self._handle_container(element)
        else:
            return self._handle_generic(element)

    def _handle_paragraph(self, element) -> List[Dict]:
        """Convert an HTML ``<p>`` into one or more block nodes.

        Usually yields a single paragraph. Block elements found inside the
        ``<p>`` (images, lists, nested ``<p>`` from unclosed tags, ...) are
        hoisted out as siblings, because a Tiptap paragraph only holds inline
        content. An empty ``<p>`` yields one non-breaking-space paragraph.
        """
        blocks = self._convert_block_children(element)
        return blocks or [self._create_paragraph("\u00A0")]

    def _handle_heading(self, element) -> Dict:
        """Convert an HTML ``h1``-``h6`` into a Tiptap heading node.

        The level is the digit in the tag name. The content is the element's
        flattened inline content (see ``_process_inline_content``).
        """
        level = int(element.name[1])
        return {
            "type": "heading",
            "attrs": {
                "level": level
            },
            "content": self._process_inline_content(element)
        }

    def _handle_table(self, element) -> Dict:
        """Convert an HTML ``<table>`` into a Tiptap table node.

        Rows come from the table's own ``<tr>`` children and its direct
        ``thead``/``tbody``/``tfoot`` sections, in document order. Rows of
        nested tables are left to the nested table's own conversion. A table
        without any cells becomes a single empty cell.
        """
        rows = []
        for tr in self._table_rows(element):
            cells = [
                self._handle_table_cell(cell)
                for cell in tr.find_all(['td', 'th'], recursive=False)
            ]
            if cells:  # Only add rows with content
                rows.append({
                    "type": "tableRow",
                    "content": cells
                })

        if not rows:
            rows = [{
                "type": "tableRow",
                "content": [{
                    "type": "tableCell",
                    "content": [self._create_paragraph("\u00A0")]
                }]
            }]

        return {
            "type": "table",
            "content": rows
        }

    def _table_rows(self, table) -> List[Any]:
        """Return the ``<tr>`` elements that belong directly to ``table``.

        Only looks one level into ``thead``/``tbody``/``tfoot``, so rows of
        tables nested inside cells are not included.
        """
        rows = []
        for child in table.children:
            tag = child.name.lower() if child.name else None
            if tag == 'tr':
                rows.append(child)
            elif tag in ('thead', 'tbody', 'tfoot'):
                rows.extend(child.find_all('tr', recursive=False))
        return rows

    def _handle_table_cell(self, cell) -> Dict:
        """Convert a ``<td>``/``<th>`` into a Tiptap tableCell node.

        Text and inline content are wrapped in paragraphs next to any block
        children. ``colspan``/``rowspan`` are emitted only when greater than 1.
        """
        cell_node = {
            "type": "tableCell",
            "content": self._convert_block_children(cell) or [self._create_paragraph("\u00A0")]
        }

        attrs = {}
        colspan = self._parse_span(cell.get('colspan'), self._MAX_COLSPAN)
        rowspan = self._parse_span(cell.get('rowspan'), self._MAX_ROWSPAN)
        if colspan > 1:
            attrs["colspan"] = colspan
        if rowspan > 1:
            attrs["rowspan"] = rowspan

        if attrs:
            cell_node["attrs"] = attrs

        return cell_node

    def _handle_list(self, element) -> Dict:
        """Convert an HTML ``ul``/``ol`` into a bulletList/orderedList node.

        Each ``<li>`` child becomes a listItem. A ``<ul>``/``<ol>`` placed
        directly inside the list (invalid HTML) is attached to the preceding
        item rather than dropped. Other children are ignored. A numeric
        ``start`` attribute on an ``<ol>`` is copied to ``attrs.start``.
        """
        list_type = "bulletList" if element.name.lower() == 'ul' else "orderedList"
        list_items: List[Dict] = []

        start_attr = element.get('start')
        logger.debug("Processing %s, HTML start attribute: %s", list_type, start_attr)

        for child in element.children:
            tag = child.name.lower() if child.name else None
            if tag == 'li':
                list_items.append(self._handle_list_item(child))
            elif tag in ('ul', 'ol'):
                nested = self._handle_list(child)
                if list_items:
                    list_items[-1]["content"].append(nested)
                else:
                    list_items.append({
                        "type": "listItem",
                        "content": [self._create_paragraph("\u00A0"), nested]
                    })

        if not list_items:
            list_items = [{
                "type": "listItem",
                "content": [self._create_paragraph("\u00A0")]
            }]

        list_node = {
            "type": list_type,
            "content": list_items
        }

        # isdecimal (unlike isdigit) rejects characters such as "³" that int()
        # cannot parse.
        if list_type == "orderedList" and start_attr and str(start_attr).strip().isdecimal():
            list_node["attrs"] = {"start": int(start_attr)}
            logger.debug("Setting start attribute to %s", start_attr)

        logger.debug("Created %s node with %d item(s)", list_type, len(list_items))
        return list_node

    def _handle_list_item(self, element) -> Dict:
        """Convert an HTML ``<li>`` into a Tiptap listItem node.

        Text and inline content are wrapped in paragraphs, in document order,
        alongside nested lists and other blocks. Tiptap's listItem content
        expression is ``paragraph block*``, so a non-breaking-space paragraph
        is prepended when the item would not start with a paragraph.
        """
        content = self._convert_block_children(element)

        if not content or content[0].get("type") != "paragraph":
            content.insert(0, self._create_paragraph("\u00A0"))

        return {
            "type": "listItem",
            "content": content
        }

    def _handle_blockquote(self, element) -> Dict:
        """Convert an HTML ``<blockquote>`` into a Tiptap blockquote node."""
        content = self._convert_block_children(element)
        return {
            "type": "blockquote",
            "content": content or [self._create_paragraph("\u00A0")]
        }

    def _handle_code_block(self, element) -> Dict:
        """Convert an HTML ``<pre>`` into a Tiptap codeBlock node.

        The code text comes from the first ``<code>`` descendant when present,
        otherwise from the ``<pre>`` itself. The language comes from a
        ``language-*`` class on ``<code>``, falling back to one on ``<pre>``.
        Whitespace-only code becomes a single non-breaking space.
        """
        code_element = element.find('code')

        if code_element is not None:
            code = code_element.get_text()
            language = (self._language_from_class(code_element)
                        or self._language_from_class(element))
        else:
            code = element.get_text()
            language = self._language_from_class(element)

        if not code.strip():
            code = "\u00A0"

        return {
            "type": "codeBlock",
            "attrs": {
                "language": language
            },
            "content": [
                {
                    "type": "text",
                    "text": code
                }
            ]
        }

    @staticmethod
    def _language_from_class(element) -> str:
        """Return the suffix of the first ``language-*`` class, or ``""``."""
        for class_name in element.get('class') or []:
            if class_name.startswith('language-'):
                return class_name[len('language-'):]
        return ""

    def _handle_horizontal_rule(self) -> Dict:
        """Convert an HTML ``<hr>`` into a Tiptap horizontalRule node."""
        return {
            "type": "horizontalRule"
        }

    def _handle_image(self, element) -> Dict:
        """Convert an HTML ``<img>`` into a Tiptap image node (treated as a block)."""
        src = element.get('src', '')
        alt = element.get('alt', '')
        title = element.get('title', '')

        return {
            "type": "image",
            "attrs": {
                "src": src,
                "alt": alt,
                "title": title
            }
        }

    def _handle_link(self, element) -> List[Dict]:
        """Convert an HTML ``<a>`` into inline text node(s) with a link mark.

        Formatting nested inside the link (bold, italic, ...) is kept as extra
        marks. A link with no visible text yields one linked non-breaking
        space so the link is not lost.
        """
        marks = [self._link_mark(element)]
        style_marks = self._formatting_marks(
            element.name.lower(), self._parse_style(element.get('style', ''))
        )
        for mark in style_marks:
            marks = self._add_mark(marks, mark)

        nodes = self._normalize_inline(
            self._collect_inline(element, marks, flatten=True)
        )
        if nodes:
            return nodes
        return [self._text_node("\u00A0", [self._link_mark(element)])]

    def _handle_break(self) -> Dict:
        """Convert an HTML ``<br>`` into a Tiptap hardBreak node."""
        return {
            "type": "hardBreak"
        }

    def _handle_container(self, element) -> List[Dict]:
        """Convert a container (div, section, ...) into its children's blocks."""
        return self._convert_block_children(element)

    def _handle_generic(self, element) -> Union[Dict, List[Dict]]:
        """Fallback for elements without a dedicated handler.

        Inline formatting tags return their inline nodes (with the tag's own
        mark applied). Everything else is treated as a container whose
        children are converted to blocks.
        """
        if element.name.lower() in self._INLINE_FORMATTING_TAGS:
            return self._normalize_inline(
                self._inline_nodes(element, [], flatten=True)
            )
        return self._convert_block_children(element)

    def _process_inline_content(self, element) -> List[Dict]:
        """Return the inline content of ``element`` as Tiptap inline nodes.

        Block descendants are flattened into the run (images are dropped),
        because callers such as headings may only hold inline content.
        Returns one non-breaking-space text node when nothing remains.
        """
        nodes = self._normalize_inline(
            self._collect_inline(element, [], flatten=True)
        )
        return nodes or [{"type": "text", "text": "\u00A0"}]

    def _convert_block_children(self, element) -> List[Dict]:
        """Convert the children of ``element`` into a list of block nodes.

        Consecutive text/inline nodes are grouped into a paragraph. Block
        elements (even ones nested inside inline elements) are emitted as
        sibling blocks between those paragraphs. Whitespace-only runs produce
        nothing.
        """
        blocks: List[Dict] = []
        run: List[Dict] = []

        def flush() -> None:
            content = self._normalize_inline(run)
            if content:
                blocks.append(self._create_paragraph_from_content(content))
            run.clear()

        for child in element.children:
            for node in self._inline_nodes(child, [], flatten=False):
                if node.get("type") in self._INLINE_NODE_TYPES:
                    run.append(node)
                else:
                    flush()
                    blocks.append(node)
        flush()
        return blocks

    def _collect_inline(self, element, marks: List[Dict], flatten: bool) -> List[Dict]:
        """Concatenate ``_inline_nodes`` over the children of ``element``."""
        nodes: List[Dict] = []
        for child in element.children:
            nodes.extend(self._inline_nodes(child, marks, flatten))
        return nodes

    def _inline_nodes(self, node, marks: List[Dict], flatten: bool) -> List[Dict]:
        """Convert one parse-tree node into Tiptap nodes in an inline context.

        Args:
            node: A BeautifulSoup Tag or string.
            marks: Marks inherited from enclosing inline elements.
            flatten: When True, block elements are descended into as plain
                wrappers, so only text/hardBreak nodes are produced. When
                False, block elements are converted via ``_convert_element``
                and returned as block nodes for the caller to hoist.

        Returns:
            Text nodes (whitespace collapsed, not yet trimmed), hardBreak
            nodes and, when ``flatten`` is False, block nodes.
        """
        if node.name is None:
            # Comments, doctypes, CDATA and processing instructions are
            # PreformattedString subclasses; their text is not page content.
            if isinstance(node, PreformattedString):
                return []
            text = self._WHITESPACE_RE.sub(" ", str(node))
            return [self._text_node(text, marks)] if text else []

        tag = node.name.lower()
        if tag in self._SKIP_TAGS:
            return []
        if tag == 'br':
            return [self._handle_break()]
        if tag in self._BLOCK_TAGS:
            if flatten:
                return self._collect_inline(node, marks, flatten)
            converted = self._convert_element(node)
            if not converted:
                return []
            return converted if isinstance(converted, list) else [converted]

        for mark in self._marks_for(node):
            marks = self._add_mark(marks, mark)
        return self._collect_inline(node, marks, flatten)

    def _normalize_inline(self, nodes: List[Dict]) -> List[Dict]:
        """Trim a run of inline nodes the way HTML renders whitespace.

        Text is expected to be whitespace-collapsed already (single spaces).
        The following adjustments are made:

        - leading spaces at the start of the run and after a hardBreak are
          dropped;
        - a space is not repeated across adjacent text nodes;
        - trailing spaces at the end of the run are removed;
        - text nodes left empty are discarded.

        Other nodes pass through unchanged. Input dicts are not mutated.
        """
        result: List[Dict] = []
        after_space = True  # Start of run: a leading space would be invisible
        for node in nodes:
            if node.get("type") != "text":
                result.append(node)
                after_space = node.get("type") == "hardBreak"
                continue
            text = node["text"].lstrip(" ") if after_space else node["text"]
            if not text:
                continue
            after_space = text.endswith(" ")
            result.append({**node, "text": text})

        # Strip trailing spaces, dropping text nodes that become empty.
        while result and result[-1].get("type") == "text":
            trimmed = result[-1]["text"].rstrip(" ")
            if trimmed:
                result[-1] = {**result[-1], "text": trimmed}
                break
            result.pop()
        return result

    def _marks_for(self, element) -> List[Dict]:
        """Return the marks an inline element adds.

        Marks come from its tag name and ``style`` attribute, plus a link mark
        for an ``<a>`` that has an ``href``.
        """
        tag = element.name.lower()
        marks = self._formatting_marks(tag, self._parse_style(element.get('style', '')))
        if tag == 'a' and element.get('href') is not None:
            marks.append(self._link_mark(element))
        return marks

    @classmethod
    def _formatting_marks(cls, tag: str, decls: Dict[str, str]) -> List[Dict]:
        """Map a tag name plus parsed style declarations to formatting marks.

        The rules are modelled on Tiptap's mark extensions:

        - ``<b>``/``<i>`` do not count when their style resets the weight or
          style to ``normal``;
        - ``font-weight`` bold/bolder/500+ counts as bold;
        - ``font-style: italic`` counts as italic;
        - ``text-decoration`` underline and line-through map to underline and
          strike.
        """
        weight = decls.get('font-weight', '')
        font_style = decls.get('font-style', '')
        decoration = ' '.join([decls.get('text-decoration', ''),
                               decls.get('text-decoration-line', '')])

        marks: List[Dict] = []
        if tag == 'strong' or (tag == 'b' and weight != 'normal') or cls._BOLD_WEIGHT_RE.match(weight):
            marks.append({"type": "bold"})
        if tag == 'em' or (tag == 'i' and font_style != 'normal') or font_style == 'italic':
            marks.append({"type": "italic"})
        if tag == 'u' or 'underline' in decoration:
            marks.append({"type": "underline"})
        if tag in ('s', 'del', 'strike') or 'line-through' in decoration:
            marks.append({"type": "strike"})
        if tag == 'code':
            marks.append({"type": "code"})
        return marks

    @staticmethod
    def _parse_style(style: str) -> Dict[str, str]:
        """Parse an inline ``style`` attribute into ``{property: value}``.

        Property names and values are lower-cased, and ``!important`` is
        removed. Declarations without a colon are skipped.
        """
        decls: Dict[str, str] = {}
        for declaration in style.split(';'):
            prop, sep, value = declaration.partition(':')
            if sep:
                value = value.lower().replace('!important', '')
                decls[prop.strip().lower()] = value.strip()
        return decls

    @staticmethod
    def _add_mark(marks: List[Dict], mark: Dict) -> List[Dict]:
        """Return a new mark list with ``mark`` added.

        A mark of the same type is replaced, so the inner link wins. The
        ``code`` mark excludes all others (Tiptap declares Code with
        ``excludes: '_'``): once present, later marks are ignored, and adding
        it drops the earlier ones.
        """
        if any(existing["type"] == "code" for existing in marks):
            return marks
        if mark["type"] == "code":
            return [mark]
        return [existing for existing in marks if existing["type"] != mark["type"]] + [mark]

    @staticmethod
    def _link_mark(element) -> Dict:
        """Build a link mark from an ``<a>`` element's href/target attributes."""
        return {
            "type": "link",
            "attrs": {
                "href": element.get('href', ''),
                "target": element.get('target', '_self')
            }
        }

    @staticmethod
    def _text_node(text: str, marks: List[Dict]) -> Dict:
        """Build a text node, adding ``marks`` only when there are any."""
        if marks:
            return {"type": "text", "marks": list(marks), "text": text}
        return {"type": "text", "text": text}

    @classmethod
    def _parse_span(cls, value: Any, upper: int = 1000) -> int:
        """Parse a colspan/rowspan attribute value leniently.

        The leading digits are used (``"2px"`` -> 2). Missing or unparsable
        values give 1, and the result is clamped to ``[1, upper]``.
        """
        if value is None:
            return 1
        match = cls._SPAN_RE.match(str(value))
        if not match:
            return 1
        return min(max(int(match.group(1)), 1), upper)

    def _create_marked_text(self, element, mark_type: str) -> Dict:
        """Create a text node covering the element's whole text with one mark.

        Not used by the conversion paths above (they accumulate marks via
        ``_inline_nodes``); kept for existing callers.
        """
        text = element.get_text() or "\u00A0"

        if mark_type == 'link':
            href = element.get('href', '')
            return {
                "type": "text",
                "marks": [{
                    "type": "link",
                    "attrs": {
                        "href": href
                    }
                }],
                "text": text
            }
        else:
            return {
                "type": "text",
                "marks": [{
                    "type": mark_type
                }],
                "text": text
            }

    def _create_paragraph(self, text: str) -> Dict:
        """Create a paragraph with one plain text node (NBSP if ``text`` is empty)."""
        return {
            "type": "paragraph",
            "content": [
                {
                    "type": "text",
                    "text": text or "\u00A0"
                }
            ]
        }

    def _create_paragraph_from_content(self, content: List[Dict]) -> Dict:
        """Create a paragraph from inline nodes (one NBSP text node if empty)."""
        return {
            "type": "paragraph",
            "content": content or [{"type": "text", "text": "\u00A0"}]
        }

def convert_html_to_json(html_content: str) -> Dict:
    """Return the Tiptap JSON document (a dict) for ``html_content``.

    Convenience wrapper: builds a default-configured ``HTMLToTiptapConverter``
    and delegates to its ``convert`` method.
    """
    return HTMLToTiptapConverter().convert(html_content)

def main(argv: Optional[List[str]] = None) -> int:
    """Run the HTML -> Tiptap JSON command-line converter.

    Reads ``input_file`` as UTF-8 and converts it with
    ``convert_html_to_json``. The result is written to ``output_file`` as
    indented UTF-8 JSON.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (for tests
            and programmatic use). ``None`` keeps the normal CLI behaviour.

    Returns:
        0 on success, 1 if reading, converting or writing failed.
    """
    parser = argparse.ArgumentParser(description='Convert HTML to Tiptap JSON format')
    parser.add_argument('input_file', help='Path to input HTML file')
    parser.add_argument('output_file', help='Path to output Tiptap JSON file')
    args = parser.parse_args(argv)

    try:
        with open(args.input_file, 'r', encoding='utf-8') as f:
            html_content = f.read()

        tiptap_json = convert_html_to_json(html_content)

        with open(args.output_file, 'w', encoding='utf-8') as f:
            json.dump(tiptap_json, f, indent=2, ensure_ascii=False)

    except Exception as e:
        # CLI boundary: report with traceback and signal failure via exit code.
        logger.exception("Error during conversion: %s", e)
        return 1

    logger.info("Conversion completed! Output saved to %s", args.output_file)
    return 0

if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is missing under
    # `python -S` and in frozen executables.
    sys.exit(main())
