#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
PERFECT FINAL ANNOTATION - Proper spacing and alignment
"""

import cv2
import numpy as np
import os

class PerfectAnnotator:
    """Locate field labels in a cell image and draw label/value boxes on it.

    Each field label is found by normalized cross-correlation template
    matching (``cv2.TM_CCOEFF_NORMED``). For every label found, a box is drawn
    around the label itself and around the value region placed to its right
    on the same line. The address value instead extends to the bottom of the
    cell. All colors are BGR tuples, as OpenCV expects.
    """

    # Template image filename for each field label, relative to template_dir.
    TEMPLATE_FILES = {
        'name': 'name_label.png',
        'voter_id': 'voter_id_label.png',
        'father': 'father_label.png',
        'mother': 'mother_label.png',
        'profession': 'profession_label.png',
        'address': 'address_label.png',
        'dob': 'dob_label.png',
    }

    # Layout constants, in pixels.
    V_PAD_TOP = 3       # value boxes start this far above the label's top edge
    V_PAD_BOTTOM = 8    # ...and end this far below the label's bottom edge
    H_GAP = 5           # horizontal gap between a label and its value box
    EDGE_MARGIN = 5     # margin kept from the cell's right/bottom edge
    SERIAL_GAP = 3      # gap between the serial box and the name label
    MIN_SERIAL_W = 10   # the serial box is drawn only when wider than this

    # "label: value" rows whose value runs to the right edge of the cell,
    # in drawing order: (field key, label caption, value caption).
    SIMPLE_FIELDS = (
        ('voter_id', 'voter_id_label', 'voter_id_value'),
        ('father', 'father_label', 'father_value'),
        ('mother', 'mother_label', 'mother_value'),
    )

    def __init__(self, template_dir='wider_templates'):
        """Load every label template as a grayscale image.

        Args:
            template_dir: folder containing the files named in
                TEMPLATE_FILES. A template that cannot be read is stored as
                None, and its field is then never matched.
        """
        self.templates = {
            field: cv2.imread(os.path.join(template_dir, filename),
                              cv2.IMREAD_GRAYSCALE)
            for field, filename in self.TEMPLATE_FILES.items()
        }

        # Colors (BGR)
        self.COLOR_LABEL = (0, 0, 255)      # RED for labels
        self.COLOR_VALUE = (255, 0, 0)      # BLUE for values
        self.COLOR_SERIAL = (0, 255, 0)     # GREEN for serial
        self.COLOR_DOB = (255, 255, 0)      # CYAN for dob

    def find_template(self, image_gray, template, threshold=0.7):
        """Find the best match of ``template`` in ``image_gray``.

        Args:
            image_gray: single-channel image to search.
            template: single-channel template image, or None.
            threshold: minimum normalized correlation score to accept.

        Returns:
            ``(x, y, w, h, score)`` of the best match, or None when the
            template is missing, does not fit inside the image, or scores
            below ``threshold``.
        """
        if template is None:
            return None

        t_h, t_w = template.shape[:2]
        i_h, i_w = image_gray.shape[:2]
        if t_h > i_h or t_w > i_w:
            # matchTemplate requires the template to fit inside the image.
            return None

        result = cv2.matchTemplate(image_gray, template, cv2.TM_CCOEFF_NORMED)
        _, max_val, _, max_loc = cv2.minMaxLoc(result)

        if max_val >= threshold:
            return (*max_loc, t_w, t_h, max_val)
        return None

    def _match_all(self, image_gray):
        """Return ``{field: (x, y, w, h, score)}`` for every template that matched."""
        positions = {}
        for field, template in self.templates.items():
            match = self.find_template(image_gray, template)
            if match is not None:
                positions[field] = match
        return positions

    def _value_box(self, label_box, cell_w, right=None):
        """Return ``(x, y, w, h)`` of the value region to the right of a label.

        Args:
            label_box: ``(x, y, w, h, ...)`` as returned by find_template.
            cell_w: width of the cell image.
            right: optional x coordinate bounding the value on the right,
                with an H_GAP gap kept before it. Defaults to the cell's
                right edge minus EDGE_MARGIN. The width can be zero or
                negative if ``right`` is not to the right of the label.
        """
        x, y, w, h = label_box[:4]
        value_x = x + w + self.H_GAP
        if right is None:
            value_w = cell_w - value_x - self.EDGE_MARGIN
        else:
            value_w = right - value_x - self.H_GAP
        return (value_x, y - self.V_PAD_TOP, value_w,
                h + self.V_PAD_TOP + self.V_PAD_BOTTOM)

    def _serial_box(self, name_box):
        """Return ``(x, y, w, h)`` of the serial-number region left of the name label."""
        x, y, _, h = name_box[:4]
        return (0, y - self.V_PAD_TOP, x - self.SERIAL_GAP,
                h + self.V_PAD_TOP + self.V_PAD_BOTTOM)

    def _address_box(self, label_box, cell_w, cell_h):
        """Return ``(x, y, w, h)`` of the address value, extending to the cell bottom."""
        value_x, value_y, value_w, _ = self._value_box(label_box, cell_w)
        return (value_x, value_y, value_w, cell_h - value_y - self.EDGE_MARGIN)

    @staticmethod
    def _draw_label(image, box, caption, color):
        """Draw a label box, with its caption 10 px above the top-left corner."""
        x, y, w, h = box[:4]
        cv2.rectangle(image, (x, y), (x + w, y + h), color, 3)
        cv2.putText(image, caption, (x, y - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 2)

    @staticmethod
    def _draw_value(image, box, caption, color):
        """Draw a value box, with its caption just inside the top-left corner."""
        x, y, w, h = box
        cv2.rectangle(image, (x, y), (x + w, y + h), color, 3)
        cv2.putText(image, caption, (x + 2, y + 18),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.35, color, 2)

    def annotate_cell(self, cell_path, output_path):
        """Annotate one cell image and write the annotated copy.

        Args:
            cell_path: path of the cell image to read.
            output_path: path the annotated image is written to.

        Returns:
            The number of field labels that were matched.

        Raises:
            FileNotFoundError: if ``cell_path`` cannot be read as an image.
            OSError: if the annotated image cannot be written.
        """
        cell_color = cv2.imread(cell_path)
        if cell_color is None:
            # imread reports failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {cell_path}")
        cell_gray = cv2.cvtColor(cell_color, cv2.COLOR_BGR2GRAY)
        cell_h, cell_w = cell_color.shape[:2]

        annotated = cell_color.copy()
        positions = self._match_all(cell_gray)

        # 1. NAME + SERIAL (the serial number sits left of the name label)
        if 'name' in positions:
            name = positions['name']
            self._draw_label(annotated, name, "name_label", self.COLOR_LABEL)
            serial = self._serial_box(name)
            if serial[2] > self.MIN_SERIAL_W:
                self._draw_value(annotated, serial, "serial_no", self.COLOR_SERIAL)
            self._draw_value(annotated, self._value_box(name, cell_w),
                             "name_value", self.COLOR_VALUE)

        # 2-4. VOTER ID, FATHER, MOTHER
        for field, label_caption, value_caption in self.SIMPLE_FIELDS:
            if field in positions:
                box = positions[field]
                self._draw_label(annotated, box, label_caption, self.COLOR_LABEL)
                self._draw_value(annotated, self._value_box(box, cell_w),
                                 value_caption, self.COLOR_VALUE)

        # 5. PROFESSION + DOB (same line; the profession value stops before DOB)
        prof = positions.get('profession')
        dob = positions.get('dob')
        if prof is not None:
            self._draw_label(annotated, prof, "prof_label", self.COLOR_LABEL)
            if dob is not None:
                prof_val = self._value_box(prof, cell_w, right=dob[0])
                # Skip when the DOB match is not to the right of the profession
                # label; a non-positive width would draw a box over the labels.
                if prof_val[2] > 0:
                    self._draw_value(annotated, prof_val, "prof_val", self.COLOR_VALUE)
            else:
                # No DOB label: this region may hold both profession and DOB.
                self._draw_value(annotated, self._value_box(prof, cell_w),
                                 "prof+dob_val", self.COLOR_VALUE)
        if dob is not None:
            # Also drawn when no profession label matched.
            self._draw_label(annotated, dob, "dob_label", self.COLOR_DOB)
            self._draw_value(annotated, self._value_box(dob, cell_w),
                             "dob_val", self.COLOR_DOB)

        # 6. ADDRESS (value spans the remaining height of the cell)
        if 'address' in positions:
            address = positions['address']
            self._draw_label(annotated, address, "address_label", self.COLOR_LABEL)
            self._draw_value(annotated, self._address_box(address, cell_w, cell_h),
                             "address_value", self.COLOR_VALUE)

        if not cv2.imwrite(output_path, annotated):
            raise OSError(f"Could not write annotated image: {output_path}")
        return len(positions)

def main():
    """Annotate every PNG cell image in a folder.

    Command line: ``[input_folder] [output_folder]``.

    The input folder defaults to ``page6_cells``. The output folder defaults
    to ``<input_folder>_annotated`` and is created if needed. Files whose
    name contains "annotated" are skipped.

    A cell that fails to read or write is reported and skipped, so one bad
    file does not abort the batch. The process exits with status 1 if the
    input folder is missing or any cell failed.
    """
    import sys

    # Get input folder from command line or use default
    input_folder = sys.argv[1] if len(sys.argv) > 1 else 'page6_cells'

    # Get output folder from command line or derive it from the input folder.
    # normpath drops a trailing slash, so "cells/" yields "cells_annotated"
    # rather than a "_annotated" subfolder inside the input folder.
    if len(sys.argv) > 2:
        output_folder = sys.argv[2]
    else:
        output_folder = f"{os.path.normpath(input_folder)}_annotated"

    # isdir (not exists) rejects a regular file here instead of letting
    # os.listdir fail on it below.
    if not os.path.isdir(input_folder):
        script = os.path.basename(sys.argv[0]) or 'annotate_final_perfect.py'
        print(f"ERROR: Input folder not found: {input_folder}")
        print(f"\nUsage: python {script} [input_folder] [output_folder]")
        sys.exit(1)

    # Run
    annotator = PerfectAnnotator()
    cell_files = sorted(
        f for f in os.listdir(input_folder)
        if f.lower().endswith('.png') and 'annotated' not in f
    )

    print("\nPERFECT ANNOTATION with proper spacing and alignment")
    print(f"Input folder: {input_folder}")
    print(f"Output folder: {output_folder}")
    print(f"Total cells: {len(cell_files)}")
    print("\nColor code:")
    print("  GREEN = serial_no")
    print("  RED   = all field labels")
    print("  BLUE  = most values")
    print("  CYAN  = dob label/value\n")

    os.makedirs(output_folder, exist_ok=True)

    failures = 0
    for cell_file in cell_files:
        cell_path = os.path.join(input_folder, cell_file)
        output_path = os.path.join(output_folder, cell_file)

        try:
            num_fields = annotator.annotate_cell(cell_path, output_path)
        except (OSError, cv2.error) as exc:
            # Report the bad cell and keep processing the rest of the batch.
            failures += 1
            print(f"{cell_file}: FAILED ({exc})")
            continue
        print(f"{cell_file}: {num_fields} fields → {output_path}")

    print(f"\n✓ Done! Check: {output_folder}/")
    if failures:
        print(f"{failures} of {len(cell_files)} cells failed.")
        sys.exit(1)

if __name__ == "__main__":
    # Run the batch annotator only when executed as a script, not on import.
    main()
