# auth: simon.mayer@onb.ac.at alexander.rabensteiner@onb.ac.at
# date: 2023-2-16
# desc: Provides functions to preprocess images to enhance tesseract results; currently applied:
#       - deskew image (code adapted from https://docs.opencv.org/4.7.0/d1/dee/tutorial_introduction_to_pca.html)
#       - crop image border regions
#       - upscale image

import os
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
from math import atan2, cos, sin, sqrt, pi
# import editdistance
from IPython.display import HTML, Image, display


def _src_from_data(data):
    """Turn raw image bytes into a ``data:`` URI usable as the src of an HTML img element.

    The bytes are wrapped in an IPython ``Image``; the first ``image/*`` entry found in
    its mime bundle is formatted as ``data:<mimetype>;base64,<value>``. Returns None
    when no such entry exists.
    """
    data_uris = (
        f'data:{mime};base64,{payload}'
        for part in Image(data=data)._repr_mimebundle_()
        for mime, payload in part.items()
        if mime.startswith('image/')
    )
    # Only the first match is needed; the generator is never advanced further.
    return next(data_uris, None)


def gallery(images, row_height='auto'):
    """Render a collection of images as a wrapping flexbox gallery for a notebook.

    Parameters
    ----------
    images: list of str or bytes
        Image URLs, or raw image bytes. Bytes are embedded as data URIs and get no
        caption; any other value is used as the img src and also shown as the caption.

    row_height: str
        CSS height applied to every image. The default 'auto' keeps the native image
        size; a fixed value such as '250px' gives all rows the same height.

    Returns
    -------
    IPython.display.HTML
        HTML object wrapping one <figure> per image.
    """
    def _figure_html(item):
        if isinstance(item, bytes):
            source, label = _src_from_data(item), ''
        else:
            source = item
            label = f'<figcaption style="font-size: 0.6em">{item}</figcaption>'
        return f'''
            <figure style="margin: 5px !important;">
              <img src="{source}" style="height: {row_height}">
              {label}
            </figure>
        '''

    markup = ''.join([_figure_html(item) for item in images])
    return HTML(data=f'''
        <div style="display: flex; flex-flow: row wrap; text-align: center;">
        {markup}
        </div>
    ''')


def draw_axis(img, p_, q_, colour, scale):
    """Draw an arrow onto ``img`` (modified in place) from ``p_`` through ``q_``.

    The shaft starts at ``p_`` and its length is the distance between ``p_`` and ``q_``
    multiplied by ``scale``. Two hooks of length 9 at +/- 45 degrees form the arrow head.
    All lines are drawn with thickness 1 and anti-aliasing in the given ``colour``.
    """
    origin = list(p_)
    target = list(q_)
    delta_y = origin[1] - target[1]
    delta_x = origin[0] - target[0]
    angle = atan2(delta_y, delta_x)  # direction of the axis in radians
    length = sqrt(delta_y * delta_y + delta_x * delta_x)

    # Arrow tip: the original end point pushed out by the scale factor
    tip_x = origin[0] - scale * length * cos(angle)
    tip_y = origin[1] - scale * length * sin(angle)
    tip = (int(tip_x), int(tip_y))

    def _line_to_tip(from_x, from_y):
        cv.line(img, (int(from_x), int(from_y)), tip, colour, 1, cv.LINE_AA)

    # Shaft first, then the two hooks of the arrow head
    _line_to_tip(origin[0], origin[1])
    for offset in (pi / 4, -pi / 4):
        _line_to_tip(tip_x + 9 * cos(angle + offset), tip_y + 9 * sin(angle + offset))


def get_orientation(pts, img, debug=False):
    """Run a PCA on contour points and report the direction and elongation of the shape.

    Parameters
    ----------
    pts: numpy.ndarray
        Contour points indexed as pts[i, 0, 0] (x) and pts[i, 0, 1] (y).
    img: numpy.ndarray
        Image that receives the debug drawing; only touched when ``debug`` is True.
    debug: bool
        If True, draw the centre and both principal axes onto ``img``.

    Returns
    -------
    tuple
        (angle, ratio): angle in radians of the first eigenvector, and the second
        eigenvalue divided by the first one.
    """
    # Flatten the contour into an (n, 2) float matrix of x/y coordinates
    samples = np.ascontiguousarray(pts[:, 0, :2], dtype=np.float64)
    mean, eigenvectors, eigenvalues = cv.PCACompute2(samples, np.empty((0)))

    first_val, second_val = eigenvalues[0, 0], eigenvalues[1, 0]
    # Centre of the object and the end points of the two axes (used for drawing)
    centre = (int(mean[0, 0]), int(mean[0, 1]))
    first_tip = (centre[0] + 0.02 * eigenvectors[0, 0] * first_val,
                 centre[1] + 0.02 * eigenvectors[0, 1] * first_val)
    second_tip = (centre[0] - 0.02 * eigenvectors[1, 0] * second_val,
                  centre[1] - 0.02 * eigenvectors[1, 1] * second_val)
    if debug:
        cv.circle(img, centre, 3, (255, 0, 255), 2)
        draw_axis(img, centre, first_tip, (0, 255, 0), 1)
        draw_axis(img, centre, second_tip, (255, 255, 0), 5)

    orientation = atan2(eigenvectors[0, 1], eigenvectors[0, 0])
    return orientation, second_val / first_val


def deskew_image_pca(image, debug=False, debug_path=None):
    """Estimate the skew of a page from PCA orientations of its contours and rotate it back.

    Parameters
    ----------
    image: numpy.ndarray
        Colour image. It is converted with cv.COLOR_BGR2GRAY, so BGR channel order
        is assumed.
    debug: bool
        If True, print the estimated angle, show a scatter plot of the per-contour angles
        and draw contours and principal axes onto an intermediate grayscale image.
    debug_path: str or None
        Path of the processed image. The debug image is written to this path with
        'img/' replaced by 'img/debug_img/' and '.jpg' replaced by '_debug.jpg'.
        If None, no debug image is written.

    Returns
    -------
    tuple
        (deskewed_image, angle) with the angle in degrees. The unmodified input image
        and 0 are returned when no suitable contour is found, or when the estimated
        angle is not finite or larger than 10 degrees.
    """
    (orig_h, orig_w) = image.shape[:2]
    # Only the centred 85 % cut-out of the page is analysed
    # (presumably to keep scan borders out of the estimate).
    shrink_factor = 0.85
    (red_h, red_w) = (int(orig_h * shrink_factor), int(orig_w * shrink_factor))
    red_h_beg = (orig_h - red_h) // 2
    red_h_end = red_h_beg + red_h
    red_w_beg = (orig_w - red_w) // 2
    red_w_end = red_w_beg + red_w
    image_cutout = image[red_h_beg:red_h_end, red_w_beg:red_w_end]
    sharp = unsharp_mask(image_cutout, amount=2.0)
    gray = cv.cvtColor(sharp, cv.COLOR_BGR2GRAY)
    blur = cv.GaussianBlur(gray, (9, 9), 2)
    bw = cv.adaptiveThreshold(blur, 255, cv.ADAPTIVE_THRESH_GAUSSIAN_C, cv.THRESH_BINARY, 5, 4)
    inverse = cv.bitwise_not(bw)
    # Wide, flat kernel: dilation mainly grows the foreground horizontally
    kernel = cv.getStructuringElement(cv.MORPH_CROSS, (10, 2))
    dilate = cv.dilate(inverse, kernel, iterations=3)

    contours, _ = cv.findContours(dilate, cv.RETR_LIST, cv.CHAIN_APPROX_SIMPLE)
    contours = sorted(contours, key=cv.contourArea, reverse=True)
    angles = []
    ratios = []
    for i, c in enumerate(contours):
        area = cv.contourArea(c)
        # Ignore contours that are too small or too large
        if area < 1e3 or 2e5 < area:
            continue
        # Draw each contour only for visualisation purposes
        if debug:
            cv.drawContours(gray, contours, i, (0, 0, 255), 2)
        # Find the orientation of each shape
        contour_angle, contour_ratio = get_orientation(c, gray, debug)
        angles.append(contour_angle)
        ratios.append(contour_ratio)

    if not angles:
        if debug:
            print('Angle too large or no contours found, returning original')
        return image, 0

    ratios = np.asarray(ratios, dtype=np.float64)
    norm_ratios = ratios / np.sum(ratios)
    # Inverse ratio as weights: a small eigenvalue ratio means an elongated shape
    weights = 1.0 / norm_ratios
    # sin(4 * angle) has a period of 90 degrees, so angles which are 90 degrees apart coincide
    angles_sin = np.sin(4 * np.asarray(angles, dtype=np.float64))
    # Perform weighted average and map it back to an angle in degrees
    angle_sin = np.average(angles_sin, weights=weights)
    angle = np.arcsin(angle_sin) / 4 * 180 / pi
    if debug:
        print(f'Weighted average angle is {round(angle, 2)} degrees')
        _, ax = plt.subplots()
        angles_deg = [contour_angle * 180 / pi for contour_angle in angles]
        ax.scatter(range(len(angles)), angles_deg, s=weights)
        ax.hlines(y=angle, xmin=0, xmax=len(angles), color='g')
        plt.show()
        # Without a path there is nowhere to store the debug image, so skip writing it
        if debug_path is not None:
            horizontal_stack = np.concatenate((inverse, dilate, gray), axis=1)
            fp = debug_path.replace('img/', 'img/debug_img/').replace('.jpg', '_debug.jpg')
            cv.imwrite(fp, horizontal_stack)

    # Return original image if the angle is unusable (not finite) or larger than 10 degrees
    if not np.isfinite(angle) or abs(angle) > 10.0:
        if debug:
            print('Angle too large or no contours found, returning original')
        return image, 0
    center = (orig_w // 2, orig_h // 2)
    rot_mat = cv.getRotationMatrix2D(center, angle, 1.0)
    deskewed_image = cv.warpAffine(image, rot_mat, (orig_w, orig_h), flags=cv.INTER_CUBIC,
                                   borderMode=cv.BORDER_REPLICATE)
    return deskewed_image, angle


def unsharp_mask(image, kernel_size=(5, 5), sigma=1.0, amount=1.0, threshold=0):
    """Return a sharpened uint8 version of the image, using an unsharp mask.

    Parameters
    ----------
    image: numpy.ndarray
        Input image; the result is clipped to [0, 255], so 8-bit data is assumed.
    kernel_size: tuple
        Kernel size passed to cv.GaussianBlur.
    sigma: float
        Sigma passed to cv.GaussianBlur.
    amount: float
        Sharpening strength: result = (amount + 1) * image - amount * blurred.
    threshold: float
        If greater than 0, pixels whose absolute difference to the blurred image is
        below this value keep their original value.

    Returns
    -------
    numpy.ndarray
        Sharpened image with dtype uint8.
    """
    blurred = cv.GaussianBlur(image, kernel_size, sigma)
    # Do all arithmetic in float: subtracting unsigned 8-bit arrays directly wraps
    # around instead of going negative, which would corrupt the contrast mask below.
    image_f = image.astype(np.float64)
    blurred_f = blurred.astype(np.float64)
    sharpened = float(amount + 1) * image_f - float(amount) * blurred_f
    sharpened = np.clip(sharpened, 0, 255).round().astype(np.uint8)
    if threshold > 0:
        low_contrast_mask = np.absolute(image_f - blurred_f) < threshold
        np.copyto(sharpened, image, where=low_contrast_mask)
    return sharpened


def analyse_deskew_method(img_filepaths, measured_angles, debug=False):
    """Compare automatically determined deskew angles with manually measured ones.

    Every image is deskewed with deskew_image_pca; the L2 distance between the calculated
    and the measured angles is printed and both series are plotted.

    Parameters
    ----------
    img_filepaths: list of str
        Paths of the images to analyse.
    measured_angles: list of float
        Reference angles in degrees, in the same order as ``img_filepaths``.
    debug: bool
        If True, the debug output of deskew_image_pca is enabled and every deskewed
        image is shown in a window until a key is pressed.

    Raises
    ------
    OSError
        If an image file cannot be read.
    """
    calc_angles = []
    for image_fp in img_filepaths:
        print(f'Preprocessing image file: {image_fp}')
        image = cv.imread(image_fp)
        # cv.imread signals an unreadable file by returning None instead of raising
        if image is None:
            raise OSError(f'Could not read image file: {image_fp}')
        # The image path is forwarded so that debug mode can derive its debug file name
        deskewed_img, angle = deskew_image_pca(image, debug=debug, debug_path=image_fp)
        calc_angles.append(angle)
        if debug:
            cv.imshow('Deskewed image', deskewed_img)
            cv.waitKey()

    # Analysis of automatically determined angles
    angle_diff = [angle - m_angle for (angle, m_angle) in zip(calc_angles, measured_angles)]
    error = np.linalg.norm(angle_diff)
    print('L2 distance between the two angle arrays is', error)
    _, ax = plt.subplots()
    ax.plot(calc_angles, 'o', label='automatic')
    ax.plot(measured_angles, 'or', label='measured')
    ax.set_xlabel('Images')
    ax.set_ylabel('Rotation angle [°]')
    ax.hlines(y=0, xmin=0, xmax=len(calc_angles) - 1, color='g')
    plt.title('Comparison of automatic and measured rotation angles')
    plt.legend()
    plt.show()


def get_cropping_corner_points(image, threshold_border=0.05, debug=False, debug_path=None):
    """Determine the bounding box of all contours that stay clear of the image border.

    Parameters
    ----------
    image: numpy.ndarray
        Colour image (converted with cv.COLOR_BGR2GRAY). It is not modified; debug
        drawings are made on a copy.
    threshold_border: float
        Fraction of the image width / height that counts as border region on each side.
    debug: bool
        If True, print details about the border region, the contours and the bounding box.
    debug_path: str or None
        Path of the processed image. In debug mode an annotated copy is written to this
        path with 'img/' replaced by 'img/debug_img/' and '.jpg' replaced by
        '_debug_crop.jpg'. If None, no debug image is written.

    Returns
    -------
    tuple
        (upper_left, lower_right) corner points as (x, y) tuples. When no contour lies
        completely inside the inner region, the corners of the whole image are returned.
    """
    # Do some preprocessing magic to extract 'good' contours
    gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY)
    _, bw = cv.threshold(gray, 50, 255, cv.THRESH_BINARY | cv.THRESH_OTSU)
    inverse = cv.bitwise_not(bw)
    kernel = cv.getStructuringElement(cv.MORPH_CROSS, (3, 3))
    dilate = cv.dilate(inverse, kernel, iterations=5)
    contours, _ = cv.findContours(dilate, cv.RETR_LIST, cv.CHAIN_APPROX_SIMPLE)

    # Define border region
    min_y = int(image.shape[0] * threshold_border)
    min_x = int(image.shape[1] * threshold_border)
    max_y = image.shape[0] - int(image.shape[0] * threshold_border)
    max_x = image.shape[1] - int(image.shape[1] * threshold_border)
    if debug:
        print(
            f'\tBorderRegion: upperleft_point={(min_x, min_y)} lowerright_point={(max_x, max_y)} '
            f'(total image size: {image.shape[1], image.shape[0]})')

    # Remove contours with at least a single point in the border region
    contours_surviving = []
    contours_removed = []  # for debug purposes
    for contour in contours:
        points = contour.reshape(-1, 2)
        xs, ys = points[:, 0], points[:, 1]
        in_inner_region = (min_x < xs) & (xs < max_x) & (min_y < ys) & (ys < max_y)
        if in_inner_region.all():
            contours_surviving.append(contour)
        else:
            contours_removed.append(contour)
    if debug:
        print(
            f'\tEliminated {len(contours_removed)} contours, {len(contours_surviving)} surviving '
            f'(total number contours: {len(contours)})')

    if contours_surviving:
        # Bounding box around the points of all surviving contours
        all_points = np.concatenate([c.reshape(-1, 2) for c in contours_surviving])
        boundingbox_of_surviving_contours = cv.boundingRect(all_points)
    else:
        # Nothing to crop to (e.g. an empty page): fall back to the whole image
        boundingbox_of_surviving_contours = (0, 0, image.shape[1], image.shape[0])
    bb_p0 = (boundingbox_of_surviving_contours[0], boundingbox_of_surviving_contours[1])
    bb_p1 = (boundingbox_of_surviving_contours[0] + boundingbox_of_surviving_contours[2],
             boundingbox_of_surviving_contours[1] + boundingbox_of_surviving_contours[3])
    if debug:
        print(
            f'\tBoundingBox of surviving contours: {boundingbox_of_surviving_contours} -> upperleft_point={bb_p0} '
            f'lowerright_point{bb_p1}')

    # TODO possibly move the following to a separate draw_nice_boxes_in_original_scan function
    if debug and debug_path is not None:
        # Draw on a copy so that the annotations never end up in the caller's image
        canvas = image.copy()
        for i, _ in enumerate(contours_removed):
            cv.drawContours(canvas, contours_removed, i, (255, 0, 255), thickness=1)
        for i, _ in enumerate(contours_surviving):
            cv.drawContours(canvas, contours_surviving, i, (0, 0, 255), thickness=1)
        cv.rectangle(canvas, bb_p0, bb_p1, (0, 255, 0), thickness=3)  # cropping
        cv.rectangle(canvas, (min_x, min_y), (max_x, max_y), (0, 255, 255), thickness=2)  # border region
        fp = debug_path.replace('img/', 'img/debug_img/').replace('.jpg', '_debug_crop.jpg')
        cv.imwrite(fp, canvas)

    return bb_p0, bb_p1


def apply_gamma_correction(img, gamma=2.0):
    """Convert ``img`` to grayscale and apply a gamma curve through a 256-entry lookup table.

    Each gray level i is mapped to (i / 255) ** gamma * 255, clipped to [0, 255] and
    stored as uint8. Returns the gamma-corrected grayscale image.
    """
    grayscale = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
    table = np.empty((1, 256), np.uint8)
    table[0, :] = [np.clip(pow(level / 255.0, gamma) * 255.0, 0, 255) for level in range(256)]
    return cv.LUT(grayscale, table)


def crop_bb_from_img(img, bb, margin=20):
    """Cut the bounding box ``bb`` out of ``img``, enlarged by ``margin`` pixels per side.

    ``bb`` holds the upper-left and lower-right corner as (x, y) pairs. The enlarged
    box is clamped to the image bounds before slicing.
    """
    height, width = img.shape[:2]
    upper_left, lower_right = bb[0], bb[1]
    col_start = max(upper_left[0] - margin, 0)
    col_stop = min(lower_right[0] + margin, width)
    row_start = max(upper_left[1] - margin, 0)
    row_stop = min(lower_right[1] + margin, height)
    return img[row_start:row_stop, col_start:col_stop]


def upscale_image(img, scaling=2):
    """Resize ``img`` by the factor ``scaling`` using bicubic interpolation.

    The target size is the image width and height multiplied by ``scaling`` and
    truncated to int.
    """
    new_size = (int(img.shape[1] * scaling), int(img.shape[0] * scaling))
    # interpolation must be passed by keyword: the third positional parameter of
    # cv.resize is the destination array, not the interpolation flag.
    return cv.resize(img, new_size, interpolation=cv.INTER_CUBIC)


def preprocess_pipeline(image, debug=False, debug_path=None):
    """Deskew, crop and grayscale an image before it is handed to OCR.

    ``debug`` and ``debug_path`` are passed on unchanged to deskew_image_pca and
    get_cropping_corner_points. Returns the cropped image converted to grayscale.
    """
    deskewed, _ = deskew_image_pca(image, debug=debug, debug_path=debug_path)
    corner_points = get_cropping_corner_points(
        deskewed, threshold_border=0.05, debug=debug, debug_path=debug_path)
    cropped = crop_bb_from_img(deskewed, corner_points, margin=15)
    # Upscaling (upscale_image) is currently disabled. Erosion, Gaussian blur, unsharp
    # masking and fastNlMeans denoising made things globally worse and are left out.
    return cv.cvtColor(cropped, cv.COLOR_BGR2GRAY)


if __name__ == '__main__':
    # Script entry point: preprocess every listed image, store the result next to the
    # input as '<name>_preprocessed.jpg' and run tesseract on it.
    import subprocess

    images = [
        '/home/simon/Downloads/Ruthenica_sample/tesseract_output/00000001.jpg',
        '/home/simon/Downloads/Ruthenica_sample/tesseract_output/00000003.jpg',
        '/home/simon/Downloads/Ruthenica_sample/tesseract_output/00000004.jpg'
    ]

    # images = [
    #     'img/test.jpg',
    #     'img/Laufon_009.jpg',
    #     'img/Laufon_012.jpg',
    #     'img/Laufon_013.jpg',
    #     'img/Laufon_055.jpg',
    #     'img/Laufon_094.jpg',
    #     'img/Laufon_110.jpg',
    #     'img/Dresden_008.jpg',
    #     'img/Dresden_020.jpg',
    #     'img/Dresden_037.jpg',
    #     'img/Dresden_063.jpg',
    #     'img/Dresden_094.jpg',
    #     'img/Dresden_132.jpg',
    #     'img/00002_1003102D_00000021.jpg',
    #     'img/00005_10026184_00000014.jpg',
    #     'img/00063_10029F7D_00000022.jpg'
    # ]
    # measured_angles = [-0.65, 0.1, -1.3, -0.4, 0.85, 1.2,
    #                    -0.1, -0.7, 0.5, -0.4, -0.7, -1.3,
    #                    0.5, 0.3, -0.7]
    # analyse_deskew_method(images, measured_angles, debug=False)
    # ground_truth_fps = [
    #     'ground_truths/GT_Laufon_009.txt',
    #     'ground_truths/GT_Laufon_012.txt',
    #     'ground_truths/GT_Laufon_013.txt',
    #     'ground_truths/GT_Laufon_055.txt',
    #     'ground_truths/GT_Laufon_094.txt',
    #     'ground_truths/GT_Laufon_110.txt',
    #     'ground_truths/GT_Dresden_008.txt',
    #     'ground_truths/GT_Dresden_020.txt',
    #     'ground_truths/GT_Dresden_037.txt',
    #     'ground_truths/GT_Dresden_063.txt',
    #     'ground_truths/GT_Dresden_094.txt',
    #     'ground_truths/GT_Dresden_132.txt',
    #     'ground_truths/GT_00002_1003102D_00000021.txt',
    #     'ground_truths/GT_00005_10026184_00000014.txt',
    #     'ground_truths/GT_00063_10029F7D_00000022.txt'
    # ]

    # ground_truths = [open(fp).read() for fp in ground_truth_fps]
    # levenshtein_distances = []

    for img_path in images:
        # '<dir>/<name>.jpg' -> '<dir>/<name>_preprocessed(.jpg)'
        ocred_path_base_preprocessed = f'{os.path.splitext(img_path)[0]}_preprocessed'
        img_path_preprocessed = f'{ocred_path_base_preprocessed}.jpg'

        print(f'Processing image {img_path}')
        input_image = cv.imread(img_path)
        # cv.imread returns None instead of raising when the file cannot be read
        if input_image is None:
            print(f'Could not read image {img_path}, skipping')
            continue
        # Debug mode needs the image path to derive the names of its debug images
        preprocessed_img = preprocess_pipeline(input_image, debug=True, debug_path=img_path)
        # cv.imshow('Preprocessed image, rotated, cropped and resized', preprocessed_img)
        # cv.waitKey()

        cv.imwrite(img_path_preprocessed, preprocessed_img)
        # Argument list instead of a shell string, so paths need no quoting
        tesseract_cmd = ['tesseract', '-l', 'ukr', '--dpi', '300', '--psm', '4',
                         img_path_preprocessed, ocred_path_base_preprocessed, 'txt']
        try:
            subprocess.run(tesseract_cmd, check=False)
        except FileNotFoundError:
            print('tesseract executable not found, skipping OCR')

    # print('Levenshtein distances are:', levenshtein_distances)
    # print('L2 norm is:', np.linalg.norm(levenshtein_distances))
