import functools
import json
import os

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from iiif_utils import get_imgurls_from_manifesturl, create_paths_from_iiifurls, download_images, download_images_multithreded, get_json_by_url


# ML code from https://huggingface.co/spaces/pytorch/ResNet
# Load a pretrained ResNet-18 once, at import time, through torch.hub using the
# hub entry point pinned to the pytorch/vision v0.10.0 tag. torch.hub fetches
# the repo/weights over the network unless they are already in its local cache.
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
# Inference only: put the model in evaluation mode (affects e.g. batch-norm layers).
model.eval()

# Resize / center-crop / normalize pipeline applied to every input image.
# Built once at import time instead of on every call to inference().
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


@functools.lru_cache(maxsize=None)
def _load_categories(path="resnet18_categories.txt"):
    """Return the class labels listed in *path*, one per line (read once, then cached)."""
    with open(path, "r", encoding="utf-8") as f:
        return tuple(line.strip() for line in f)


def inference(input_image, top_k=5):
    """Classify a single PIL image with the module-level ``model``.

    Args:
        input_image: PIL image to classify (callers in this file pass an
            image already converted to RGB).
        top_k: number of highest-probability classes to return. Defaults to 5,
            matching the original behaviour; clamped to the number of classes.

    Returns:
        dict mapping category label -> softmax probability (Python float) for
        the ``top_k`` most likely classes.
    """
    # Add a leading batch dimension: the model is called on a mini-batch.
    input_batch = _PREPROCESS(input_image).unsqueeze(0)

    # Move the input and model to the GPU for speed if available.
    if torch.cuda.is_available():
        input_batch = input_batch.to("cuda")
        model.to("cuda")

    with torch.no_grad():
        logits = model(input_batch)
    # The model outputs unnormalized scores; softmax over the scores of the
    # single batch item turns them into probabilities.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    categories = _load_categories()
    k = min(top_k, probabilities.numel())
    top_prob, top_catid = torch.topk(probabilities, k)
    return {
        categories[int(cat_id)]: prob.item()
        for prob, cat_id in zip(top_prob, top_catid)
    }


# gradio related code

# Rows currently shown in the examples Dataset: one single-element list
# ``[image_path]`` per downloaded image. Replaced wholesale by update_examples()
# and read (by row index) in load_example() / apply_model().
# NOTE(review): this is process-wide module state, so every browser session
# shares, and can overwrite, the same list.
samples = []

def get_new_examples(url):
    """Fetch an IIIF manifest, download its images, and return them as Dataset rows.

    Side effects: writes the manifest JSON to ``current_manifest.json`` (kept
    for a later export) and downloads the images into the ``images/`` directory,
    creating it if needed.

    Args:
        url: URL of the IIIF manifest, as entered in the textbox. Surrounding
            whitespace (common when pasting) is ignored.

    Returns:
        list of single-element lists ``[[path], ...]``, one row per image,
        matching the Dataset's single ``preview`` component.
    """
    url = url.strip()
    # Get and save the manifest for later export.
    manifest = get_json_by_url(url)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs('images', exist_ok=True)
    with open('current_manifest.json', 'w') as f:
        json.dump(manifest, f, indent=4)
    # Get the image URLs from the manifest and download them.
    # NOTE(review): this helper presumably fetches the manifest a second time;
    # if iiif_utils offers a variant taking the parsed `manifest`, prefer it.
    img_urls = get_imgurls_from_manifesturl(url)
    img_paths = create_paths_from_iiifurls(img_urls, 'images')
    download_images_multithreded(img_urls, img_paths, nb_threads=10)
    return [[p] for p in img_paths]

def update_examples(url):
    """Load the manifest at ``url`` and push its images into the examples Dataset.

    Replaces the module-level ``samples`` list with the freshly downloaded rows
    and returns a Dataset update carrying those same rows.
    """
    global samples
    fresh_rows = get_new_examples(url)
    samples = fresh_rows
    return gr.Dataset.update(samples=fresh_rows)

def load_example(example_id):
    """Return the image path stored in row ``example_id`` of ``samples``.

    Only reads the module-level list, so no ``global`` declaration is needed.
    """
    row = samples[example_id]
    return row[0]

def apply_model(example_id):
    """Classify the image stored in row ``example_id`` of ``samples``.

    Args:
        example_id: row index of the clicked Dataset example.

    Returns:
        dict of category label -> probability, as produced by ``inference``.
    """
    img_path = samples[example_id][0]
    # Open the image in a context manager so the file handle is closed
    # deterministically. convert() loads the pixels and returns a new image,
    # so the result stays usable after the file is closed.
    with Image.open(img_path) as img:
        rgb_img = img.convert('RGB')
    return inference(rgb_img)


# UI layout.
# Left column: manifest URL input, load button, preview of the selected image
# and the clickable list of downloaded images.
# Right column: top-5 predicted classes for the selected image.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            link = gr.Textbox(label='Link to iiif manifest')
            load_manifest = gr.Button(value="Load manifest")
            preview = gr.Image(label="Selected image", height=400)
            # type="index": click handlers receive the clicked row's index,
            # which load_example/apply_model use to index into `samples`.
            examples = gr.Dataset(samples=samples, components=[preview], type="index")

        with gr.Column():
            res_labels = gr.Label(type="confidences",num_top_classes=5)

        # Loading a manifest replaces the rows of the examples Dataset.
        load_manifest.click(update_examples, inputs=link, outputs=examples)
        # Clicking a row triggers two separate listeners: one shows the image
        # in the preview, the other runs the classifier and fills the labels.
        examples.click(load_example, inputs=examples, outputs=preview)
        examples.click(apply_model, inputs=examples, outputs=res_labels)

# NOTE(review): launch() runs at import time; consider an
# `if __name__ == "__main__":` guard if this module is ever imported elsewhere.
demo.launch()
