#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Debug OCR issues - visualize regions and test different configurations
"""

import cv2
import numpy as np
from PIL import Image
import pytesseract
import re
import os

# Import the extractor to reuse template matching logic
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from workflow.cells_to_csv import CellsToCSVExtractor


def _draw_labeled_box(image, box, label, color):
    """Draw an (x, y, w, h) rectangle on *image* in place, with *label* just above it."""
    x, y, w, h = box
    cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
    cv2.putText(image, label, (x, y - 5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1)


def _value_box_right_of(label_match, extractor, right_edge):
    """Return the (x, y, w, h) value box that sits to the right of a matched label.

    The box starts ``extractor.H_GAP`` px after the label's right edge, is
    padded vertically by ``extractor.V_PAD_TOP`` / ``extractor.V_PAD_BOTTOM``
    around the label's height, and ends at *right_edge*.

    Args:
        label_match: a ``find_template`` result whose first four items are
            the label's x, y, w, h.
        extractor: object providing ``H_GAP``, ``V_PAD_TOP``, ``V_PAD_BOTTOM``.
        right_edge: x coordinate where the value box should end.
    """
    label_x, label_y, label_w, label_h = label_match[:4]
    value_x = label_x + label_w + extractor.H_GAP
    value_y = label_y - extractor.V_PAD_TOP
    value_h = label_h + extractor.V_PAD_TOP + extractor.V_PAD_BOTTOM
    return value_x, value_y, right_edge - value_x, value_h


def visualize_regions(cell_path, output_path):
    """Draw, OCR and report every field region of one cell image.

    Each label template found by ``extractor.find_template`` is drawn in blue.
    Every derived value region is drawn in its field colour, and the text
    returned by ``extractor.extract_text_region`` for it is printed. The
    annotated copy is written to *output_path*. Errors (unreadable input,
    failed write) are printed rather than raised.

    Args:
        cell_path: path of the cell image to analyse.
        output_path: destination of the annotated image (format is chosen by
            ``cv2.imwrite`` from the file extension).
    """
    extractor = CellsToCSVExtractor()

    cell_color = cv2.imread(cell_path)
    if cell_color is None:
        print(f"Error: Could not load {cell_path}")
        return

    cell_gray = cv2.cvtColor(cell_color, cv2.COLOR_BGR2GRAY)
    cell_h, cell_w = cell_color.shape[:2]

    # Annotate a copy so OCR below still reads the untouched pixels.
    annotated = cell_color.copy()

    print(f"\nAnalyzing: {cell_path}")
    print(f"Image size: {cell_w}x{cell_h}")

    # Locate every label template and mark each hit in blue (BGR).
    positions = {}
    for template_name, template in extractor.templates.items():
        match = extractor.find_template(cell_gray, template)
        if match:
            positions[template_name] = match
            _draw_labeled_box(annotated, match[:4], f"{template_name}", (255, 0, 0))

    # BGR colours for the value regions.
    colors = {
        'serial': (0, 255, 0),        # Green
        'name': (0, 255, 255),        # Yellow
        'voter_id': (255, 0, 255),    # Magenta
        'father': (0, 165, 255),      # Orange
        'mother': (255, 255, 0),      # Cyan
        'profession': (128, 0, 128),  # Purple
        'dob': (0, 128, 255),         # Orange-red
        'address': (255, 128, 0),     # Sky blue
    }

    def report(field, box):
        """Draw *box* for *field*, OCR it from the clean image and print the text."""
        _draw_labeled_box(annotated, box, field, colors[field])
        x, y, w, h = box
        text = extractor.extract_text_region(cell_color, x, y, w, h)
        print(f"  {field.upper()} [{w}x{h}]: '{text}'")

    # Value boxes stop 5 px short of the cell's right edge.
    right_margin = cell_w - 5

    # Serial number: the strip left of the "name" label, on the same row.
    # The name value is the strip to the right of that label.
    if 'name' in positions:
        name_x, name_y, _, name_h = positions['name'][:4]
        serial_box = (0, name_y - extractor.V_PAD_TOP, name_x - 3,
                      name_h + extractor.V_PAD_TOP + extractor.V_PAD_BOTTOM)
        if serial_box[2] > 10:  # skip strips too narrow to hold a number
            report('serial', serial_box)
        report('name', _value_box_right_of(positions['name'], extractor, right_margin))

    for field in ('voter_id', 'father', 'mother'):
        if field in positions:
            report(field, _value_box_right_of(positions[field], extractor, right_margin))

    if 'profession' in positions:
        # Stop at the DOB label when one was found (presumably on the same
        # row). Otherwise use the same right margin as every other field.
        right_edge = positions['dob'][0] if 'dob' in positions else right_margin
        report('profession',
               _value_box_right_of(positions['profession'], extractor, right_edge))

    if 'dob' in positions:
        report('dob', _value_box_right_of(positions['dob'], extractor, right_margin))

    if 'address' in positions:
        # Address gets all remaining height down to 5 px above the cell
        # bottom (presumably because it can wrap onto several lines).
        x, y, w, _ = _value_box_right_of(positions['address'], extractor, right_margin)
        report('address', (x, y, w, cell_h - y - 5))

    # cv2.imwrite reports failure via its return value, not an exception.
    if cv2.imwrite(output_path, annotated):
        print(f"\nAnnotated image saved to: {output_path}")
    else:
        print(f"\nError: Could not write annotated image to {output_path}")


def test_preprocessing(cell_path, field_coords, field_name):
    """Run Tesseract (``ben+eng``, ``--psm 7``) on one field under several preprocessing variants.

    Variants tried, in order: the original colour crop, grayscale, Otsu
    binarisation, non-local-means denoising, denoising + Otsu, CLAHE contrast
    enhancement, and a 2x bicubic upscale. The OCR result of each is printed.

    Args:
        cell_path: path of the cell image.
        field_coords: ``(x, y, w, h)`` of the field. It is clipped to the
            image, so negative offsets or boxes running past the edge are
            accepted.
        field_name: label used in the printed report.
    """
    cell_color = cv2.imread(cell_path)
    if cell_color is None:
        print(f"Error: Could not load {cell_path}")
        return

    # Clip the requested box to the image bounds.
    x, y, w, h = field_coords
    img_h, img_w = cell_color.shape[:2]
    x0, y0 = max(x, 0), max(y, 0)
    x1, y1 = min(x + w, img_w), min(y + h, img_h)
    if x1 <= x0 or y1 <= y0:
        print(f"Error: region {field_coords} does not overlap the {img_w}x{img_h} image")
        return

    region = cell_color[y0:y1, x0:x1]

    print(f"\n{'='*60}")
    print(f"Testing preprocessing on {field_name}")
    print(f"{'='*60}")

    config = r'--oem 3 --psm 7'

    def ocr(title, image):
        """Print *title*, OCR *image* (RGB or single-channel array) and print the result."""
        print(f"\n{title}")
        text = pytesseract.image_to_string(Image.fromarray(image), lang='ben+eng', config=config)
        print(f"   Result: '{text.strip()}'")

    # OpenCV loads BGR, but PIL reads 3-channel arrays as RGB. Convert colour
    # crops first so Tesseract sees the real colours.
    ocr("1. ORIGINAL (no preprocessing):", cv2.cvtColor(region, cv2.COLOR_BGR2RGB))

    gray = cv2.cvtColor(region, cv2.COLOR_BGR2GRAY)
    ocr("2. GRAYSCALE:", gray)

    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    ocr("3. BINARY THRESHOLD:", binary)

    denoised = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21)
    ocr("4. DENOISED:", denoised)

    _, denoised_binary = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    ocr("5. DENOISED + BINARY:", denoised_binary)

    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    ocr("6. CONTRAST ENHANCED (CLAHE):", clahe.apply(gray))

    scaled = cv2.resize(region, None, fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
    ocr("7. SCALED UP 2x:", cv2.cvtColor(scaled, cv2.COLOR_BGR2RGB))


def test_psm_modes(cell_path, field_coords, field_name):
    """Run Tesseract (``ben+eng``, ``--oem 3``) on one field under several page-segmentation modes.

    Args:
        cell_path: path of the cell image.
        field_coords: ``(x, y, w, h)`` of the field. It is clipped to the
            image, so negative offsets or boxes running past the edge are
            accepted.
        field_name: label used in the printed report.
    """
    cell_color = cv2.imread(cell_path)
    if cell_color is None:
        print(f"Error: Could not load {cell_path}")
        return

    # Clip the requested box to the image bounds.
    x, y, w, h = field_coords
    img_h, img_w = cell_color.shape[:2]
    x0, y0 = max(x, 0), max(y, 0)
    x1, y1 = min(x + w, img_w), min(y + h, img_h)
    if x1 <= x0 or y1 <= y0:
        print(f"Error: region {field_coords} does not overlap the {img_w}x{img_h} image")
        return

    # OpenCV loads BGR, but PIL reads 3-channel arrays as RGB, so convert first.
    region_rgb = cv2.cvtColor(cell_color[y0:y1, x0:x1], cv2.COLOR_BGR2RGB)
    region_pil = Image.fromarray(region_rgb)

    print(f"\n{'='*60}")
    print(f"Testing PSM modes on {field_name}")
    print(f"{'='*60}")

    psm_modes = {
        3: "Fully automatic page segmentation",
        6: "Assume a single uniform block of text",
        7: "Treat image as single text line (CURRENT)",
        8: "Treat image as single word",
        11: "Sparse text - find as much text as possible",
        13: "Raw line - treat as single text line, bypass layout",
    }

    for psm, description in psm_modes.items():
        text = pytesseract.image_to_string(region_pil, lang='ben+eng',
                                           config=f'--oem 3 --psm {psm}')
        print(f"\nPSM {psm} - {description}:")
        print(f"   Result: '{text.strip()}'")


def main():
    """Visualise cells 6 and 12 of page 3, then run OCR experiments on cell 12's serial box.

    Cells whose image file does not exist are skipped. The serial box is
    derived from the "name" label match, the same way ``visualize_regions``
    derives it.
    """
    # Analyze problematic cells from page 3
    cell6_path = 'page3_output/cells/page003_cell06.png'
    cell12_path = 'page3_output/cells/page003_cell12.png'

    # 1. Visualize regions
    print("="*80)
    print("STEP 1: VISUALIZE REGION EXTRACTION")
    print("="*80)

    for cell_path, output_path in ((cell6_path, 'debug_cell6_regions.png'),
                                   (cell12_path, 'debug_cell12_regions.png')):
        if os.path.exists(cell_path):
            visualize_regions(cell_path, output_path)

    # 2. Test serial number extraction on cell 12 (which failed)
    if not os.path.exists(cell12_path):
        return

    extractor = CellsToCSVExtractor()
    cell_color = cv2.imread(cell12_path)
    if cell_color is None:  # exists but unreadable/corrupt
        print(f"Error: Could not load {cell12_path}")
        return
    cell_gray = cv2.cvtColor(cell_color, cv2.COLOR_BGR2GRAY)

    # Only the "name" label is needed to locate the serial-number box.
    name_template = extractor.templates.get('name')
    if name_template is None:
        return
    name_match = extractor.find_template(cell_gray, name_template)
    if not name_match:
        return

    name_x, name_y, _, name_h = name_match[:4]
    serial_box = (0, name_y - extractor.V_PAD_TOP, name_x - 3,
                  name_h + extractor.V_PAD_TOP + extractor.V_PAD_BOTTOM)

    label = "SERIAL NUMBER (Cell 12)"
    test_preprocessing(cell12_path, serial_box, label)
    test_psm_modes(cell12_path, serial_box, label)


if __name__ == '__main__':
    main()
