#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Module: Cells to CSV Extractor
Extracts voter data from cell images using OCR
Supports both Tesseract and Google Cloud Vision API with batch processing
"""

import cv2
import numpy as np
from PIL import Image
import pytesseract
import pandas as pd
import re
import os
import time
from collections import defaultdict


class CellsToCSVExtractor:
    def __init__(self, ocr_engine='tesseract', google_credentials_path=None,
                 batch_size=16, max_requests_per_minute=1800, gemini_api_key=None, gemini_model='gemini-2.5-flash'):
        """
        Initialize the extractor with specified OCR engine

        If the requested cloud engine cannot be imported or initialized, a
        warning is printed and the extractor falls back to Tesseract.

        Args:
            ocr_engine: 'tesseract', 'google', or 'gemini' (default: 'tesseract').
                None is treated as 'tesseract'.
            google_credentials_path: Path to Google Cloud credentials JSON file
            batch_size: Number of images to process per batch; clamped to the
                range 1..16 (max 16 for Google, default: 16)
            max_requests_per_minute: Rate limit for API calls (default: 1800)
            gemini_api_key: Google Gemini API key (for gemini ocr_engine)
            gemini_model: Gemini model name (default: 'gemini-2.5-flash')
        """
        # Tolerate ocr_engine=None instead of failing on .lower()
        self.ocr_engine = (ocr_engine or 'tesseract').lower()
        # Google's limit is 16; a zero or negative size would break the
        # range() stepping used by the batch methods, so keep it at >= 1
        self.batch_size = max(1, min(batch_size, 16))
        self.max_requests_per_minute = max_requests_per_minute
        self.request_times = []  # Track request timestamps for rate limiting

        # Cloud clients stay None unless the matching engine initializes
        self.vision_client = None
        self.gemini_model = None
        self.gemini_model_name = None  # Store model name string for cost calculation

        if self.ocr_engine == 'google':
            try:
                from google.cloud import vision
                if google_credentials_path:
                    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = google_credentials_path
                self.vision_client = vision.ImageAnnotatorClient()
                print(f"  Using Google Cloud Vision API for OCR (batch mode: {self.batch_size} images)")
                print(f"  Rate limit: {self.max_requests_per_minute} requests/minute")
            except ImportError:
                print("  Warning: google-cloud-vision not installed. Falling back to Tesseract.")
                print("  Install with: pip install google-cloud-vision")
                self.ocr_engine = 'tesseract'
            except Exception as e:
                print(f"  Warning: Failed to initialize Google Cloud Vision: {e}")
                print("  Falling back to Tesseract.")
                self.ocr_engine = 'tesseract'
        elif self.ocr_engine == 'gemini':
            try:
                import google.generativeai as genai
                if gemini_api_key:
                    genai.configure(api_key=gemini_api_key)
                self.gemini_model = genai.GenerativeModel(gemini_model)
                self.gemini_model_name = gemini_model  # Store model name string
                print(f"  Using Google Gemini model: {gemini_model}")
                print(f"  Rate limit: {self.max_requests_per_minute} requests/minute")
            except ImportError:
                print("  Warning: google-generativeai not installed. Falling back to Tesseract.")
                print("  Install with: pip install google-generativeai")
                self.ocr_engine = 'tesseract'
            except Exception as e:
                print(f"  Warning: Failed to initialize Google Gemini: {e}")
                print("  Falling back to Tesseract.")
                self.ocr_engine = 'tesseract'
        else:
            print("  Using Tesseract OCR")

        # Template directory (relative to workflow folder)
        template_dir = os.path.join(os.path.dirname(__file__), '..', 'wider_templates')

        # Load all label templates as grayscale images. cv2.imread returns
        # None for a missing/unreadable file; find_template() skips such
        # entries, so report them here instead of failing silently later.
        self.templates = {}
        for field in ('name', 'voter_id', 'father', 'mother', 'profession', 'address', 'dob'):
            template_path = os.path.join(template_dir, f'{field}_label.png')
            template = cv2.imread(template_path, cv2.IMREAD_GRAYSCALE)
            if template is None:
                print(f"  Warning: could not load template '{field}' from {template_path}")
            self.templates[field] = template

        # Padding constants (pixels) applied around a matched label when
        # cropping the value region next to it
        self.V_PAD_TOP = 3
        self.V_PAD_BOTTOM = 8
        self.H_GAP = 5

    def find_template(self, image_gray, template, threshold=0.7):
        """Find template in image"""
        if template is None:
            return None

        result = cv2.matchTemplate(image_gray, template, cv2.TM_CCOEFF_NORMED)
        min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)

        if max_val >= threshold:
            h, w = template.shape
            return (*max_loc, w, h, max_val)
        return None

    def _enhance_for_ocr(self, image):
        """Boost local contrast of an image before OCR.

        The image is reduced to grayscale (when it has a channel axis),
        equalized with CLAHE, and returned as a 3-channel BGR image.

        Args:
            image: Input image, either 3-dimensional (BGR) or 2-dimensional

        Returns:
            Contrast-enhanced BGR image
        """
        has_channels = len(image.shape) == 3
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if has_channels else image.copy()

        # CLAHE = Contrast Limited Adaptive Histogram Equalization; intended to
        # make text stand out from the watermark without being too aggressive
        equalizer = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))

        # Back to BGR so callers always receive a 3-channel image
        return cv2.cvtColor(equalizer.apply(gray), cv2.COLOR_GRAY2BGR)

    def rate_limit_check(self, num_images=1):
        """
        Check and enforce rate limiting for API calls

        Args:
            num_images: Number of images in this request (each counts toward quota)
        """
        current_time = time.time()

        # Remove requests older than 1 minute (sliding window)
        self.request_times = [t for t in self.request_times if current_time - t < 60]

        # If adding these images would exceed quota, wait until oldest requests age out
        while len(self.request_times) + num_images > self.max_requests_per_minute:
            if self.request_times:
                # Calculate how long to wait for oldest request to age out
                oldest_time = self.request_times[0]
                sleep_time = 60.1 - (current_time - oldest_time)  # Add 0.1s buffer

                if sleep_time > 0:
                    print(f"    Rate limit: {len(self.request_times)}/{self.max_requests_per_minute} used. Waiting {sleep_time:.1f}s...")
                    time.sleep(sleep_time)
                    current_time = time.time()
                    # Remove requests older than 1 minute after sleep
                    self.request_times = [t for t in self.request_times if current_time - t < 60]
                else:
                    # Remove the oldest entry and check again
                    self.request_times.pop(0)
            else:
                # No requests in window but somehow still over limit - shouldn't happen
                break

        # Add a timestamp for each image in the batch (Google counts each image)
        for _ in range(num_images):
            self.request_times.append(current_time)

    def batch_ocr_google(self, image_regions):
        """
        Process multiple image regions in a single batch request

        Args:
            image_regions: List of tuples (image_data, region_id)

        Returns:
            Dict mapping region_id to extracted text
        """
        if not self.vision_client:
            return {}

        from google.cloud import vision

        results = {}
        total_batches = (len(image_regions) + self.batch_size - 1) // self.batch_size

        # Process in batches
        for batch_idx, batch_start in enumerate(range(0, len(image_regions), self.batch_size), 1):
            batch_end = min(batch_start + self.batch_size, len(image_regions))
            batch = image_regions[batch_start:batch_end]

            # Rate limiting - pass number of images (Google counts each image toward quota)
            self.rate_limit_check(num_images=len(batch))

            print(f"    Processing batch {batch_idx}/{total_batches} ({len(batch)} images, {len(self.request_times)} in quota window)...")

            # Build batch request
            requests = []
            for image_data, region_id in batch:
                # Preprocess to enhance text visibility and reduce watermark interference
                processed_image = self._enhance_for_ocr(image_data)

                # Encode image
                _, encoded_image = cv2.imencode('.png', processed_image)
                content = encoded_image.tobytes()

                # Create image request
                image = vision.Image(content=content)
                requests.append({
                    'image': image,
                    'features': [{'type_': vision.Feature.Type.TEXT_DETECTION}]
                })

            try:
                # Batch annotate - correct API format
                response = self.vision_client.batch_annotate_images(requests=requests)

                # Extract results
                for idx, image_response in enumerate(response.responses):
                    region_id = batch[idx][1]

                    if image_response.error.message:
                        results[region_id] = ""
                    elif image_response.text_annotations:
                        results[region_id] = image_response.text_annotations[0].description.strip()
                    else:
                        results[region_id] = ""

            except Exception as e:
                print(f"    Batch OCR error: {e}")
                # Fill with empty strings for failed batch
                for _, region_id in batch:
                    results[region_id] = ""

        return results

    def extract_text_region_google(self, image, x, y, w, h):
        """Extract text using Google Cloud Vision API

        Args:
            image: Source image
            x, y, w, h: Region coordinates
        """
        # Handle negative coordinates and bounds
        if y < 0:
            h = h + y
            y = 0
        if x < 0:
            w = w + x
            x = 0

        img_h, img_w = image.shape[:2]
        if x + w > img_w:
            w = img_w - x
        if y + h > img_h:
            h = img_h - y

        if w <= 0 or h <= 0:
            return ""

        # Crop region
        region = image[y:y+h, x:x+w]

        try:
            # Encode image to bytes
            from google.cloud import vision
            _, encoded_image = cv2.imencode('.png', region)
            content = encoded_image.tobytes()

            # Create Vision API image
            vision_image = vision.Image(content=content)

            # Perform text detection
            response = self.vision_client.text_detection(image=vision_image)

            if response.error.message:
                print(f"    Google Vision API error: {response.error.message}")
                return ""

            texts = response.text_annotations
            if texts:
                # First annotation contains all text
                return texts[0].description.strip()
            return ""

        except Exception as e:
            print(f"    Google Vision API exception: {e}")
            return ""

    def extract_text_region(self, image, x, y, w, h, use_binary=False):
        """Extract text from region using configured OCR engine

        Args:
            image: Source image
            x, y, w, h: Region coordinates
            use_binary: If True, apply binary threshold (only for Tesseract)
        """
        # Use Google Cloud Vision if configured
        if self.ocr_engine == 'google' and self.vision_client:
            return self.extract_text_region_google(image, x, y, w, h)

        # Otherwise use Tesseract with preprocessing
        # Handle negative coordinates
        if y < 0:
            h = h + y
            y = 0
        if x < 0:
            w = w + x
            x = 0

        # Handle out of bounds
        img_h, img_w = image.shape[:2]
        if x + w > img_w:
            w = img_w - x
        if y + h > img_h:
            h = img_h - y

        if w <= 0 or h <= 0:
            return ""

        # Crop region
        region = image[y:y+h, x:x+w]

        # Preprocessing based on field type (Tesseract only)
        if len(region.shape) == 3:
            gray = cv2.cvtColor(region, cv2.COLOR_BGR2GRAY)
        else:
            gray = region

        if use_binary:
            # Binary thresholding: Good for numeric fields (serial, voter_id, dob)
            # Helps with faint or low-contrast numbers
            _, processed = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        else:
            # Grayscale only: Better for Bengali text fields
            # Binary threshold can distort complex Bengali characters
            processed = gray

        # Convert to PIL Image
        region_pil = Image.fromarray(processed)

        # OCR with Bengali + English
        text = pytesseract.image_to_string(region_pil, lang='ben+eng', config=r'--oem 3 --psm 7')
        return text.strip()

    def bengali_to_english_number(self, text):
        """Convert Bengali numerals to English"""
        if not text:
            return text
        trans = str.maketrans('০১২৩৪৫৬৭৮৯', '0123456789')
        return text.translate(trans)

    def clean_text(self, text):
        """Clean extracted text"""
        if not text:
            return text
        text = re.sub(r'[|\\<>{}[\]()"\']', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def extract_regions_for_batch(self, cell_path, page_num, cell_num):
        """
        Cut a cell image into per-field value regions for batch OCR

        Label templates are located in the cell, and for each label found the
        area holding its value is cropped. Region ids have the form
        "p{page_num}_c{cell_num}_{field}" with field one of serial, name,
        voter_id, father, mother, profession, dob, address (in that order).

        Returns:
            Tuple of (regions_list, positions, cell_color, cell_dimensions)
            where regions_list contains (image_data, region_id) tuples,
            positions maps template name to its (x, y, w, h, score) match and
            cell_dimensions is (height, width). All four are None when the
            image cannot be read or no template matches.
        """
        cell_color = cv2.imread(cell_path)
        if cell_color is None:
            return None, None, None, None

        cell_gray = cv2.cvtColor(cell_color, cv2.COLOR_BGR2GRAY)
        cell_h, cell_w = cell_color.shape[:2]

        # Locate every label template; keep only the ones that matched
        candidates = (
            (label, self.find_template(cell_gray, template))
            for label, template in self.templates.items()
        )
        positions = {label: hit for label, hit in candidates if hit}

        # If no templates found, likely a migrated cell
        if not positions:
            return None, None, None, None

        regions = []
        v_pad = self.V_PAD_TOP + self.V_PAD_BOTTOM

        def queue(field, crop):
            """Register a crop under this cell's id for `field` (None is skipped)."""
            if crop is not None:
                regions.append((crop, f"p{page_num}_c{cell_num}_{field}"))

        def value_box(label):
            """Left edge, top edge and height of the value strip beside a label."""
            lx, ly, lw, lh, _ = positions[label]
            return lx + lw + self.H_GAP, ly - self.V_PAD_TOP, lh + v_pad

        def queue_to_right_edge(field, left, top, height):
            """Queue a strip running from `left` to 5 px before the cell's right edge."""
            queue(field, self._crop_region(cell_color, left, top, cell_w - left - 5, height,
                                           trim_right_whitespace=True))

        if 'name' in positions:
            name_x = positions['name'][0]
            value_x, row_y, row_h = value_box('name')

            # 1. SERIAL NUMBER: strip to the left of the name label, same row
            serial_w = name_x - 3
            if serial_w > 10 and row_y >= 0:
                queue('serial', self._crop_region(cell_color, 0, row_y, serial_w, row_h))

            # 2. NAME
            queue_to_right_edge('name', value_x, row_y, row_h)

        # 3. VOTER ID: 10 px extra margin above and below
        # (Google Vision needs good margins)
        if 'voter_id' in positions:
            vid_x, vid_y, vid_h = value_box('voter_id')
            queue_to_right_edge('voter_id', vid_x, vid_y - 10, vid_h + 20)

        # 4-5. FATHER, MOTHER
        for parent in ('father', 'mother'):
            if parent in positions:
                queue_to_right_edge(parent, *value_box(parent))

        # 6. PROFESSION: value ends at the DOB label when that label was found
        if 'profession' in positions:
            prof_x, _, prof_w, _, _ = positions['profession']
            prof_value_x, prof_y, prof_h = value_box('profession')
            if 'dob' in positions:
                prof_value_w = positions['dob'][0] - (prof_x + prof_w) - self.H_GAP
            else:
                prof_value_w = cell_w - (prof_x + prof_w) - 5
            queue('profession', self._crop_region(cell_color, prof_value_x, prof_y, prof_value_w, prof_h,
                                                  trim_right_whitespace=True))

        # 7. DATE OF BIRTH
        if 'dob' in positions:
            queue_to_right_edge('dob', *value_box('dob'))

        # 8. ADDRESS: runs down to 5 px above the bottom of the cell, untrimmed
        if 'address' in positions:
            addr_x, addr_y, _ = value_box('address')
            queue('address', self._crop_region(cell_color, addr_x, addr_y,
                                               cell_w - addr_x - 5, cell_h - addr_y - 5))

        return regions, positions, cell_color, (cell_h, cell_w)

    def _crop_region(self, image, x, y, w, h, trim_right_whitespace=False):
        """Safely crop a region from an image

        Args:
            image: Source image
            x, y, w, h: Region coordinates
            trim_right_whitespace: If True, crop white space from right side until black text found
        """
        # Handle negative coordinates
        if y < 0:
            h = h + y
            y = 0
        if x < 0:
            w = w + x
            x = 0

        # Handle out of bounds
        img_h, img_w = image.shape[:2]
        if x + w > img_w:
            w = img_w - x
        if y + h > img_h:
            h = img_h - y

        if w <= 0 or h <= 0:
            return None

        # Crop the region
        region = image[y:y+h, x:x+w]

        # Trim white space from right if requested
        if trim_right_whitespace:
            # Convert to grayscale if needed
            if len(region.shape) == 3:
                gray = cv2.cvtColor(region, cv2.COLOR_BGR2GRAY)
            else:
                gray = region

            # Find rightmost column with text (not solid borders)
            # Dark pixel threshold: < 240 (white is 255)
            dark_threshold = 240

            # Scan from right to left, skip last 5 pixels (usually borders)
            rightmost_text = 0
            for col in range(gray.shape[1] - 6, -1, -1):
                column_pixels = gray[:, col]
                dark_count = np.sum(column_pixels < dark_threshold)

                # Text should have some dark pixels but not ALL pixels black (not a solid line)
                # Text typically covers 20-80% of column height
                if dark_count > 0 and dark_count < (len(column_pixels) * 0.9):
                    rightmost_text = col
                    break

            # Crop to rightmost text position + margin
            if rightmost_text > 0:
                margin = 15  # Small margin after text
                new_width = min(rightmost_text + margin, region.shape[1])
                region = region[:, :new_width]

        return region

    def batch_ocr_full_cells(self, cell_images):
        """
        Process full cell images in batches with Google Cloud Vision

        Args:
            cell_images: List of tuples (cell_image, page_num, cell_num, cell_path)

        Returns:
            Dict mapping (page_num, cell_num) to full OCR text
        """
        if not self.vision_client:
            return {}

        from google.cloud import vision

        results = {}
        total_batches = (len(cell_images) + self.batch_size - 1) // self.batch_size

        # Process in batches
        for batch_idx, batch_start in enumerate(range(0, len(cell_images), self.batch_size), 1):
            batch_end = min(batch_start + self.batch_size, len(cell_images))
            batch = cell_images[batch_start:batch_end]

            # Rate limiting
            self.rate_limit_check(num_images=len(batch))

            print(f"    Processing batch {batch_idx}/{total_batches} ({len(batch)} cells, {len(self.request_times)} in quota window)...")

            # Build batch request
            requests = []
            for cell_image, page_num, cell_num, cell_path in batch:
                # Encode image
                _, encoded_image = cv2.imencode('.png', cell_image)
                content = encoded_image.tobytes()

                # Create image request
                image = vision.Image(content=content)
                requests.append({
                    'image': image,
                    'features': [{'type_': vision.Feature.Type.TEXT_DETECTION}]
                })

            try:
                # Batch annotate
                response = self.vision_client.batch_annotate_images(requests=requests)

                # Extract results
                for idx, image_response in enumerate(response.responses):
                    cell_image, page_num, cell_num, cell_path = batch[idx]

                    if image_response.error.message:
                        results[(page_num, cell_num)] = ""
                    elif image_response.text_annotations:
                        # Store full text
                        full_text = image_response.text_annotations[0].description.strip()
                        results[(page_num, cell_num)] = full_text
                    else:
                        results[(page_num, cell_num)] = ""

            except Exception as e:
                print(f"    Batch OCR error: {e}")
                # Fill with empty strings for failed batch
                for cell_image, page_num, cell_num, cell_path in batch:
                    results[(page_num, cell_num)] = ""

        return results

    def batch_ocr_full_cells_gemini(self, cell_images):
        """
        Process full cell images with Google Gemini 2.5 Flash-Lite

        Args:
            cell_images: List of tuples (cell_image, page_num, cell_num, cell_path)

        Returns:
            Dict mapping (page_num, cell_num) to full OCR text
        """
        if not self.gemini_model:
            return {}

        results = {}
        total_images = len(cell_images)

        # Gemini prompt for Bengali OCR
        prompt = """Extract all text from this Bengali voter information card image with MAXIMUM ACCURACY.

CRITICAL: Pay special attention to Bengali numerals (০১২৩৪৫৬৭৮৯) which appear in:
- Serial numbers (example: ০০০১, ০০১৫)
- Voter ID numbers (example: ৬৮০৩২৩১৬৯৬৫১)
- Dates (example: ০১/০১/১৯৮০)

The image contains voter details in Bengali language with fields like:
- Serial number (4 digits with dot)
- নাম (Name)
- ভোটার নং (Voter ID Number) - VERY IMPORTANT: 12-13 digit number
- পিতা (Father's Name)
- মাতা (Mother's Name)
- পেশা (Profession)
- জন্ম তারিখ (Date of Birth)
- ঠিকানা (Address)

Extract ALL text EXACTLY as it appears in Bengali script, maintaining original spelling and numbers.
Return ONLY the raw text without any translations, formatting, or added labels."""

        # Process each image individually (Gemini doesn't support batch like Cloud Vision)
        for idx, (cell_image, page_num, cell_num, cell_path) in enumerate(cell_images, 1):
            # Rate limiting - 1 request per image
            self.rate_limit_check(num_images=1)

            if idx % 10 == 0 or idx == total_images:
                print(f"    Processing image {idx}/{total_images} ({len(self.request_times)} in quota window)...")

            try:
                # Convert CV2 image to PIL Image
                # CV2 uses BGR, PIL uses RGB
                if len(cell_image.shape) == 3:
                    rgb_image = cv2.cvtColor(cell_image, cv2.COLOR_BGR2RGB)
                else:
                    rgb_image = cv2.cvtColor(cell_image, cv2.COLOR_GRAY2RGB)
                pil_image = Image.fromarray(rgb_image)

                # Generate response
                response = self.gemini_model.generate_content([prompt, pil_image])

                # Extract text
                if response and response.text:
                    full_text = response.text.strip()
                    results[(page_num, cell_num)] = full_text
                else:
                    results[(page_num, cell_num)] = ""

            except Exception as e:
                print(f"    Gemini OCR error for page {page_num} cell {cell_num}: {e}")
                results[(page_num, cell_num)] = ""

        return results

    def parse_full_cell_ocr(self, ocr_results, page_num, cell_num):
        """Parse full cell OCR text into voter record"""
        full_text = ocr_results.get((page_num, cell_num), "")
        if not full_text:
            return None

        voter = {
            'page': page_num,
            'cell': cell_num
        }

        # Split into lines
        lines = full_text.split('\n')

        # Extract fields by parsing the text
        for line in lines:
            line = line.strip()

            # Serial number (starts with digits and dot)
            if re.match(r'^[০-৯0-9।|]+\.', line):
                serial_match = re.search(r'([০-৯0-9]{1,4})\.', line)
                if serial_match:
                    serial = self.bengali_to_english_number(serial_match.group(1))
                    voter['serial_no'] = serial.zfill(4)

            # Name
            if 'নাম:' in line or 'নাম :' in line:
                name_text = re.sub(r'.*?নাম\s*:\s*', '', line)
                name_clean = self.clean_text(name_text)
                if len(name_clean) >= 3:
                    voter['name'] = name_clean

            # Voter ID
            if 'ভোটার নং:' in line or 'ভোটার নং :' in line:
                text = re.sub(r'.*?ভোটার নং\s*:\s*', '', line)
                vid_clean = re.sub(r'[^\d]', '', self.bengali_to_english_number(text))
                if 10 <= len(vid_clean) <= 14:
                    voter['voter_id'] = vid_clean

            # Father
            if 'পিতা:' in line or 'পিতা :' in line:
                father = re.sub(r'.*?পিতা\s*:\s*', '', line)
                father = self.clean_text(father)
                if len(father) >= 3:
                    voter['father_name'] = father

            # Mother
            if 'মাতা:' in line or 'মাতা :' in line:
                mother = re.sub(r'.*?মাতা\s*:\s*', '', line)
                mother = self.clean_text(mother)
                if len(mother) >= 3:
                    voter['mother_name'] = mother

            # Profession and DOB (often on same line)
            if 'পেশা:' in line or 'পেশা :' in line:
                # Extract profession
                prof_text = re.sub(r'.*?পেশা\s*:\s*', '', line)
                # Split at জন্ম তারিখ if exists
                if 'জন্ম তারিখ' in prof_text:
                    parts = prof_text.split('জন্ম তারিখ')
                    profession = self.clean_text(parts[0])
                    profession = re.sub(r'[,।]', '', profession).strip()
                    if len(profession) >= 2:
                        voter['profession'] = profession

                    # Extract DOB
                    if len(parts) > 1:
                        dob_text = parts[1].replace(':', '').strip()
                        dob_match = re.search(r'([০-৯0-9]{2}[/.]?[০-৯0-9]{2}[/.]?[০-৯0-9]{4})', dob_text)
                        if dob_match:
                            dob = self.bengali_to_english_number(dob_match.group(1))
                            dob = dob.replace('.', '/')
                            voter['date_of_birth'] = dob
                else:
                    profession = self.clean_text(prof_text)
                    profession = re.sub(r'[,।]', '', profession).strip()
                    if len(profession) >= 2:
                        voter['profession'] = profession

            # Address
            if 'ঠিকানা:' in line or 'ঠিকানা :' in line:
                address = re.sub(r'.*?ঠিকানা\s*:\s*', '', line)
                address = self.clean_text(address)
                if len(address) >= 5:
                    voter['address'] = address

        return voter if len(voter) > 2 else None

    def parse_batch_results(self, ocr_results, page_num, cell_num):
        """Parse batch OCR results into voter record"""
        voter = {
            'page': page_num,
            'cell': cell_num
        }

        # 1. SERIAL NUMBER
        serial_key = f"p{page_num}_c{cell_num}_serial"
        if serial_key in ocr_results:
            serial_text = ocr_results[serial_key]
            serial_match = re.search(r'([০-৯0-9]{1,4})', serial_text)
            if serial_match:
                serial = self.bengali_to_english_number(serial_match.group(1))
                voter['serial_no'] = serial.zfill(4)

        # 2. NAME
        name_key = f"p{page_num}_c{cell_num}_name"
        if name_key in ocr_results:
            name_text = self.clean_text(ocr_results[name_key])
            if len(name_text) >= 3:
                voter['name'] = name_text

        # 3. VOTER ID
        vid_key = f"p{page_num}_c{cell_num}_voter_id"
        if vid_key in ocr_results:
            text = ocr_results[vid_key]
            vid_clean = re.sub(r'[^\d]', '', self.bengali_to_english_number(text))
            if 10 <= len(vid_clean) <= 14:
                voter['voter_id'] = vid_clean

        # 4. FATHER
        father_key = f"p{page_num}_c{cell_num}_father"
        if father_key in ocr_results:
            father = self.clean_text(ocr_results[father_key])
            if len(father) >= 3:
                voter['father_name'] = father

        # 5. MOTHER
        mother_key = f"p{page_num}_c{cell_num}_mother"
        if mother_key in ocr_results:
            mother = self.clean_text(ocr_results[mother_key])
            if len(mother) >= 3:
                voter['mother_name'] = mother

        # 6. PROFESSION
        prof_key = f"p{page_num}_c{cell_num}_profession"
        if prof_key in ocr_results:
            profession = self.clean_text(ocr_results[prof_key])
            profession = re.sub(r'[,।]', '', profession).strip()
            if len(profession) >= 2:
                voter['profession'] = profession

        # 7. DATE OF BIRTH
        dob_key = f"p{page_num}_c{cell_num}_dob"
        if dob_key in ocr_results:
            text = ocr_results[dob_key]
            dob_match = re.search(r'([০-৯0-9]{2}[/.]?[০-৯0-9]{2}[/.]?[০-৯0-9]{4})', text)
            if dob_match:
                dob = self.bengali_to_english_number(dob_match.group(1))
                dob = dob.replace('.', '/')
                voter['date_of_birth'] = dob

        # 8. ADDRESS
        addr_key = f"p{page_num}_c{cell_num}_address"
        if addr_key in ocr_results:
            address = self.clean_text(ocr_results[addr_key])
            if len(address) >= 5:
                voter['address'] = address

        return voter if len(voter) > 2 else None

    def extract_cell(self, cell_path, page_num, cell_num):
        """Extract voter data from a single cell image.

        Every field label is located with ``self.find_template``; the value of
        a field is then OCR'd from the rectangle to the right of its label.
        The serial number is read from the strip to the left of the name label.

        Args:
            cell_path: Path of the cell image to read.
            page_num: Page number stored in the returned record.
            cell_num: Cell number stored in the returned record.

        Returns:
            dict with 'page', 'cell' and every field that passed validation,
            or None when the image cannot be read, no label template matches,
            or no field could be extracted.
        """
        cell_color = cv2.imread(cell_path)
        if cell_color is None:
            return None

        cell_gray = cv2.cvtColor(cell_color, cv2.COLOR_BGR2GRAY)
        cell_h, cell_w = cell_color.shape[:2]

        # Locate every field label; each stored match is (x, y, w, h, score)
        positions = {}
        for template_name, template in self.templates.items():
            match = self.find_template(cell_gray, template)
            if match:
                positions[template_name] = match

        # If no templates found, likely a migrated cell
        if not positions:
            return None

        row_pad = self.V_PAD_TOP + self.V_PAD_BOTTOM
        right_margin = cell_w - 5  # every value stops 5 px short of the right edge

        def read_box(x, y, w, h, use_binary=False):
            """OCR one rectangle; return '' when the rectangle has no area."""
            if y < 0:
                # Label sits within V_PAD_TOP of the top edge: clip the box
                # instead of handing a negative row offset downstream.
                h += y
                y = 0
            if w <= 0 or h <= 0:
                return ''
            if use_binary:
                return self.extract_text_region(cell_color, x, y, w, h, use_binary=True)
            return self.extract_text_region(cell_color, x, y, w, h)

        def read_value(label, right=right_margin, to_bottom=False, use_binary=False):
            """OCR the value right of `label`, up to `right`; optionally down to the cell bottom."""
            label_x, label_y, label_w, label_h, _score = positions[label]
            x = label_x + label_w + self.H_GAP
            y = label_y - self.V_PAD_TOP
            h = cell_h - y - 5 if to_bottom else label_h + row_pad
            return read_box(x, y, right - x, h, use_binary)

        voter = {
            'page': page_num,
            'cell': cell_num
        }

        if 'name' in positions:
            name_x, name_y, _name_w, name_h, _name_score = positions['name']

            # 1. SERIAL NUMBER: strip left of the name label
            #    (binary threshold for better number recognition)
            serial_w = name_x - 3
            if serial_w > 10:
                serial_text = read_box(0, name_y - self.V_PAD_TOP, serial_w,
                                       name_h + row_pad, use_binary=True)
                serial_match = re.search(r'([০-৯0-9]{1,4})', serial_text)
                if serial_match:
                    serial = self.bengali_to_english_number(serial_match.group(1))
                    voter['serial_no'] = serial.zfill(4)

            # 2. NAME
            name_clean = self.clean_text(read_value('name'))
            if len(name_clean) >= 3:
                voter['name'] = name_clean

        # 3. VOTER ID (binary threshold for better number recognition)
        if 'voter_id' in positions:
            text = read_value('voter_id', use_binary=True)
            vid_clean = re.sub(r'[^\d]', '', self.bengali_to_english_number(text))
            if 10 <= len(vid_clean) <= 14:
                voter['voter_id'] = vid_clean

        # 4./5. FATHER and MOTHER
        for label, field in (('father', 'father_name'), ('mother', 'mother_name')):
            if label in positions:
                parent = self.clean_text(read_value(label))
                if len(parent) >= 3:
                    voter[field] = parent

        # 6. PROFESSION: bounded on the right by the DOB label when one was
        #    found, otherwise by the same right margin as every other field
        if 'profession' in positions:
            right = positions['dob'][0] if 'dob' in positions else right_margin
            profession = self.clean_text(read_value('profession', right=right))
            profession = re.sub(r'[,।]', '', profession).strip()
            if len(profession) >= 2:
                voter['profession'] = profession

        # 7. DATE OF BIRTH (binary threshold for better number recognition)
        if 'dob' in positions:
            text = read_value('dob', use_binary=True)
            dob_match = re.search(r'([০-৯0-9]{2}[/.]?[০-৯0-9]{2}[/.]?[০-৯0-9]{4})', text)
            if dob_match:
                dob = self.bengali_to_english_number(dob_match.group(1))
                voter['date_of_birth'] = dob.replace('.', '/')

        # 8. ADDRESS: runs from the label row down to 5 px above the cell bottom
        if 'address' in positions:
            address = self.clean_text(read_value('address', to_bottom=True))
            if len(address) >= 5:
                voter['address'] = address

        return voter if len(voter) > 2 else None  # More than just page/cell

    def extract_all_cells(self, cells_dir, output_csv):
        """Extract data from all cell images in directory with batch processing"""
        # Get all cell image files
        cell_files = sorted([
            f for f in os.listdir(cells_dir)
            if f.endswith('.png') and 'annotated' not in f.lower() and 'boxes' not in f.lower()
        ])

        all_voters = []

        # Use batch processing for Google Cloud Vision or Gemini
        if (self.ocr_engine == 'google' and self.vision_client) or (self.ocr_engine == 'gemini' and self.gemini_model):
            if self.ocr_engine == 'gemini':
                print(f"\n  Processing {len(cell_files)} cells with Gemini (full cell OCR)...")
            else:
                print(f"\n  Batch processing {len(cell_files)} cells (full cell OCR for 100% accuracy)...")

            # Collect full cell images instead of cropped regions
            cell_images = []
            cell_metadata = []

            print(f"  Loading {len(cell_files)} cell images...")
            for idx, cell_file in enumerate(cell_files, 1):
                cell_path = os.path.join(cells_dir, cell_file)

                # Parse page and cell number from filename
                parts = cell_file.replace('.png', '').split('_')
                page_num = int(parts[0].replace('page', ''))
                cell_num = int(parts[1].replace('cell', ''))

                # Load cell image
                cell_image = cv2.imread(cell_path)
                if cell_image is not None:
                    cell_images.append((cell_image, page_num, cell_num, cell_path))
                    cell_metadata.append((page_num, cell_num))

            print(f"  Loaded {len(cell_images)} cell images")

            # Batch OCR full cells
            if self.ocr_engine == 'gemini':
                print(f"  Processing {len(cell_images)} cells with Gemini...")
            else:
                print(f"  Processing {len(cell_images)} cells in batches of {self.batch_size}...")
            start_time = time.time()

            if self.ocr_engine == 'gemini':
                cell_ocr_results = self.batch_ocr_full_cells_gemini(cell_images)
            else:
                cell_ocr_results = self.batch_ocr_full_cells(cell_images)

            elapsed = time.time() - start_time
            print(f"  OCR completed in {elapsed:.1f}s ({len(cell_images)/elapsed:.1f} cells/sec)")

            # Calculate API costs
            num_images = len(cell_images)

            if self.ocr_engine == 'gemini':
                # Gemini pricing varies by model
                # Estimate ~1000 tokens per image input, ~200 tokens per output
                input_tokens = num_images * 1000
                output_tokens = num_images * 200

                # Determine pricing based on model
                if 'flash-lite' in self.gemini_model_name.lower():
                    input_cost = 0.10
                    output_cost = 0.40
                    pricing_info = f"{self.gemini_model_name}: $0.10/1M input + $0.40/1M output tokens"
                elif 'flash' in self.gemini_model_name.lower():
                    input_cost = 0.15
                    output_cost = 0.60
                    pricing_info = f"{self.gemini_model_name}: $0.15/1M input + $0.60/1M output tokens"
                elif 'pro' in self.gemini_model_name.lower():
                    input_cost = 1.25
                    output_cost = 5.00
                    pricing_info = f"{self.gemini_model_name}: $1.25/1M input + $5.00/1M output tokens"
                else:
                    # Default to Flash pricing
                    input_cost = 0.15
                    output_cost = 0.60
                    pricing_info = f"{self.gemini_model_name}: $0.15/1M input + $0.60/1M output tokens (estimated)"

                cost = (input_tokens / 1_000_000) * input_cost + (output_tokens / 1_000_000) * output_cost
                api_calls_info = f"{num_images} individual requests"
            else:
                # Google Cloud Vision pricing: $1.50/1000 for first 1000, then $0.60/1000
                if num_images <= 1000:
                    cost = (num_images / 1000) * 1.50
                else:
                    cost = 1.50 + ((num_images - 1000) / 1000) * 0.60
                pricing_info = "Google Cloud Vision: $1.50/1000 images for first 1000, then $0.60/1000"
                api_calls_info = f"{(num_images + self.batch_size - 1) // self.batch_size} batches (batch size: {self.batch_size})"

            print(f"\n  API Usage Summary:")
            print(f"    Total images processed: {num_images}")
            print(f"    API calls made: {api_calls_info}")
            print(f"    Estimated cost: ${cost:.4f} USD")
            print(f"    ({pricing_info})")

            # Parse results for each cell
            for page_num, cell_num in cell_metadata:
                voter = self.parse_full_cell_ocr(cell_ocr_results, page_num, cell_num)

                if voter:
                    all_voters.append(voter)
                    print(f"  Page {page_num:3d} Cell {cell_num:2d}: {voter.get('name', 'N/A'):30s}")

        else:
            # Use Tesseract (single cell processing)
            for cell_file in cell_files:
                cell_path = os.path.join(cells_dir, cell_file)

                # Parse page and cell number from filename
                parts = cell_file.replace('.png', '').split('_')
                page_num = int(parts[0].replace('page', ''))
                cell_num = int(parts[1].replace('cell', ''))

                voter = self.extract_cell(cell_path, page_num, cell_num)

                if voter:
                    all_voters.append(voter)
                    print(f"  Page {page_num:3d} Cell {cell_num:2d}: {voter.get('name', 'N/A'):30s}")

        # Create DataFrame
        if all_voters:
            df = pd.DataFrame(all_voters)

            # Reorder columns
            column_order = ['page', 'cell', 'serial_no', 'voter_id', 'name',
                           'father_name', 'mother_name', 'address', 'date_of_birth', 'profession']

            for col in column_order:
                if col not in df.columns:
                    df[col] = ''

            df = df[column_order]

            # Save to CSV
            df.to_csv(output_csv, index=False, encoding='utf-8-sig')

            return len(df)

        return 0
