#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Final extraction using all templates with proper vertical spacing
"""

import cv2
import numpy as np
from PIL import Image
import pytesseract
import pandas as pd
import re
import os

class FinalVoterExtractor:
    """Extract voter-record fields from cropped voter-list cell images.

    Each field is located by template-matching a pre-cropped image of its
    printed label, then OCR-ing (Tesseract, 'ben+eng') the strip of the
    cell to the right of the matched label. The serial number is read from
    the strip to the LEFT of the name label.
    """

    # Field keys; each is loaded from '<template_dir>/<key>_label.png'.
    TEMPLATE_KEYS = ('name', 'voter_id', 'father', 'mother',
                     'profession', 'address', 'dob')
    # Pixels left blank between a value region and its right/bottom bound.
    VALUE_MARGIN = 5

    def __init__(self, template_dir='wider_templates', threshold=0.7):
        """Load all label templates as grayscale images.

        Args:
            template_dir: directory holding the '<key>_label.png' templates.
            threshold: minimum TM_CCOEFF_NORMED score for a label to count
                as found in extract_cell.
        """
        self.threshold = threshold
        self.templates = {
            key: cv2.imread(os.path.join(template_dir, f'{key}_label.png'),
                            cv2.IMREAD_GRAYSCALE)
            for key in self.TEMPLATE_KEYS
        }

        print("Templates loaded:")
        for name, template in self.templates.items():
            if template is not None:
                print(f"  {name}: {template.shape}")
            else:
                # cv2.imread returns None (no exception) for a bad path, so
                # report it instead of silently never extracting this field.
                print(f"  {name}: MISSING (field will not be extracted)")

    def find_template(self, image_gray, template, threshold=0.7):
        """Locate the best match of `template` inside `image_gray`.

        Returns:
            (x, y, w, h, score) of the best match if its normalized
            correlation score is >= threshold, else None. Also None when the
            template is missing or larger than the image (matchTemplate
            requires the template to fit inside the image).
        """
        if template is None:
            return None
        t_h, t_w = template.shape[:2]
        i_h, i_w = image_gray.shape[:2]
        if t_h > i_h or t_w > i_w:
            return None

        result = cv2.matchTemplate(image_gray, template, cv2.TM_CCOEFF_NORMED)
        _, max_val, _, max_loc = cv2.minMaxLoc(result)

        if max_val >= threshold:
            return (*max_loc, t_w, t_h, max_val)
        return None

    def extract_text_region(self, image, x, y, w, h, psm=7):
        """OCR the box (x, y, w, h) of a BGR/grayscale image.

        The box is clipped to the image on all four sides; the far edge of
        the box stays where the caller put it, so a box lying entirely
        outside the image yields "".

        Args:
            psm: Tesseract page-segmentation mode (7 = single text line,
                6 = uniform block of text for multi-line regions).

        Returns:
            The stripped OCR text, or "" for an empty clipped box.
        """
        img_h, img_w = image.shape[:2]
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, img_w), min(y + h, img_h)
        if x1 <= x0 or y1 <= y0:
            return ""

        region = image[y0:y1, x0:x1]
        if region.ndim == 3 and region.shape[2] == 3:
            # OpenCV images are BGR; PIL interprets 3-channel arrays as RGB.
            region = cv2.cvtColor(region, cv2.COLOR_BGR2RGB)
        region_pil = Image.fromarray(region)

        text = pytesseract.image_to_string(region_pil, lang='ben+eng', config=f'--oem 3 --psm {psm}')
        return text.strip()

    def bengali_to_english_number(self, text):
        """Convert Bengali numerals (০-৯) to ASCII digits; falsy input is returned as-is."""
        if not text:
            return text
        trans = str.maketrans('০১২৩৪৫৬৭৮৯', '0123456789')
        return text.translate(trans)

    def clean_text(self, text):
        """Drop bracket/quote/pipe OCR noise and collapse whitespace; falsy input is returned as-is."""
        if not text:
            return text
        text = re.sub(r'[|\\<>{}[\]()"\']', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def _read_right_of(self, image, box, right_edge=None, height=None, psm=7):
        """OCR the value printed to the right of a matched label `box`.

        The region starts at the label's right edge and its top, ends
        VALUE_MARGIN px before `right_edge` (default: image width) and is
        `height` tall (default: the label's height).
        """
        label_x, label_y, label_w, label_h, _ = box
        value_x = label_x + label_w
        if right_edge is None:
            right_edge = image.shape[1]
        value_w = right_edge - value_x - self.VALUE_MARGIN
        value_h = label_h if height is None else height
        return self.extract_text_region(image, value_x, label_y, value_w, value_h, psm=psm)

    def extract_cell(self, cell_path):
        """Extract voter data from a single cell image.

        Returns:
            A dict with whichever of 'serial_no', 'name', 'voter_id',
            'father_name', 'mother_name', 'profession', 'date_of_birth',
            'address' could be read, or None if nothing was extracted.

        Raises:
            OSError: if cv2 cannot read the image at `cell_path`.
        """
        cell_color = cv2.imread(cell_path)
        if cell_color is None:
            raise OSError(f"Could not read cell image: {cell_path}")
        cell_gray = cv2.cvtColor(cell_color, cv2.COLOR_BGR2GRAY)
        cell_h, cell_w = cell_color.shape[:2]

        positions = {}
        for key, template in self.templates.items():
            match = self.find_template(cell_gray, template, self.threshold)
            if match:
                positions[key] = match

        voter = {}

        # SERIAL NUMBER (left of the name label) and NAME (right of it)
        if 'name' in positions:
            name_x, name_y, _, name_h, _ = positions['name']
            serial_text = self.extract_text_region(cell_color, 0, name_y, name_x, name_h)
            serial_match = re.search(r'([০-৯0-9]{1,4})', serial_text)
            if serial_match:
                serial = self.bengali_to_english_number(serial_match.group(1))
                voter['serial_no'] = serial.zfill(4)

            name = self.clean_text(self._read_right_of(cell_color, positions['name']))
            if len(name) >= 3:
                voter['name'] = name

        # VOTER ID - keep ASCII digits only (\d would also keep other scripts' digits)
        if 'voter_id' in positions:
            text = self._read_right_of(cell_color, positions['voter_id'])
            vid = re.sub(r'[^0-9]', '', self.bengali_to_english_number(text))
            if 10 <= len(vid) <= 14:
                voter['voter_id'] = vid

        # FATHER / MOTHER names
        for key, field in (('father', 'father_name'), ('mother', 'mother_name')):
            if key in positions:
                value = self.clean_text(self._read_right_of(cell_color, positions[key]))
                if len(value) >= 3:
                    voter[field] = value

        # PROFESSION - stop before the DOB label only when it shares the line
        # and lies to the right; otherwise the width would be wrong/negative.
        if 'profession' in positions:
            prof_x, prof_y, prof_w, prof_h, _ = positions['profession']
            right_edge = cell_w
            if 'dob' in positions:
                dob_x, dob_y, _, dob_h, _ = positions['dob']
                same_line = dob_y < prof_y + prof_h and prof_y < dob_y + dob_h
                if same_line and dob_x > prof_x + prof_w:
                    right_edge = dob_x
            text = self._read_right_of(cell_color, positions['profession'], right_edge=right_edge)
            # Remove common OCR artifacts
            profession = re.sub(r'[,।]', '', self.clean_text(text)).strip()
            if len(profession) >= 2:
                voter['profession'] = profession

        # DATE OF BIRTH - separators '/', '.' or '-' are normalised to '/'
        if 'dob' in positions:
            text = self._read_right_of(cell_color, positions['dob'])
            dob_match = re.search(r'([০-৯0-9]{2}[/.\-]?[০-৯0-9]{2}[/.\-]?[০-৯0-9]{4})', text)
            if dob_match:
                dob = self.bengali_to_english_number(dob_match.group(1))
                voter['date_of_birth'] = re.sub(r'[.\-]', '/', dob)

        # ADDRESS - from the label line down to the bottom of the cell; this
        # can span several lines, so OCR it as a text block (psm 6), not a line.
        if 'address' in positions:
            label_y = positions['address'][1]
            text = self._read_right_of(cell_color, positions['address'],
                                       height=cell_h - label_y - self.VALUE_MARGIN, psm=6)
            address = self.clean_text(text)
            if len(address) >= 5:
                voter['address'] = address

        return voter or None

def main(cell_dir='page6_cells', limit=5):
    """Run FinalVoterExtractor on the first `limit` cell images and print results.

    Cell files are those named 'cell_*.png' in `cell_dir`, ordered with a
    natural sort so that e.g. 'cell_2.png' comes before 'cell_10.png'.
    """
    def natural_key(name):
        # re.split with a capture group alternates text/number parts, so
        # numbers always sit at odd indices and comparisons stay type-safe.
        parts = re.split(r'(\d+)', name)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], name

    extractor = FinalVoterExtractor()

    cell_files = sorted(
        (f for f in os.listdir(cell_dir) if f.startswith('cell_') and f.endswith('.png')),
        key=natural_key,
    )

    print(f"\nTesting extraction on first {limit} cells...\n")

    for cell_file in cell_files[:limit]:
        print(f"{cell_file}:")

        voter = extractor.extract_cell(os.path.join(cell_dir, cell_file))

        if voter:
            for key, value in voter.items():
                print(f"  {key}: {value}")
        else:
            print("  (empty/migrated)")

        print()


if __name__ == "__main__":
    main()
