{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3035f8a5-3aa6-4b11-8507-1e61e846fe23",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e098af0e-21cb-42ce-84bd-fcb49e5d7f81",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import ast\n",
    "import json\n",
    "import pathlib\n",
    "import re\n",
    "\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import requests\n",
    "from PIL import Image\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7548cbea-2330-45d7-91b0-7f8644b0e6a0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_content(url, max_retries=3, timeout=30):\n",
    "    \"\"\"Download `url` and return the raw response body.\n",
    "\n",
    "    The request is attempted up to `max_retries` times; an attempt counts as\n",
    "    failed when it times out or answers with a status other than 200.\n",
    "\n",
    "    Args:\n",
    "        url: Address to fetch with an HTTP GET.\n",
    "        max_retries: Maximum number of attempts before giving up.\n",
    "        timeout: Seconds to wait for the server on each attempt. Without a\n",
    "            timeout `requests` waits indefinitely, so the Timeout retry\n",
    "            below would never trigger.\n",
    "\n",
    "    Returns:\n",
    "        The response body as bytes, or None when every attempt failed.\n",
    "\n",
    "    Raises:\n",
    "        SystemExit: On any `requests` error other than a timeout (this\n",
    "            includes TooManyRedirects).\n",
    "    \"\"\"\n",
    "    for _ in range(max_retries):\n",
    "        try:\n",
    "            resp = requests.get(url, timeout=timeout)\n",
    "        except requests.exceptions.Timeout:\n",
    "            # Timeouts are treated as transient: try again.\n",
    "            continue\n",
    "        except requests.exceptions.RequestException as e:\n",
    "            # TooManyRedirects is a RequestException and was handled identically.\n",
    "            raise SystemExit(e)\n",
    "        if resp.status_code == 200:\n",
    "            return resp.content\n",
    "    # Every attempt failed; callers have to check for None.\n",
    "    return None\n",
    "\n",
    "def get_category(entry):\n",
    "    \"\"\"Map `entry` onto a category label.\n",
    "\n",
    "    Returns `entry` unchanged when it is 'A', 'B' or 'C'; every other value\n",
    "    yields 'N'.\n",
    "    \"\"\"\n",
    "    known_categories = ('A', 'B', 'C')\n",
    "    return entry if entry in known_categories else 'N'\n",
    "\n",
    "def display_image_from_series(df, num_images=16):\n",
    "    \"\"\"Show up to `num_images` images from `df` in a grid with 8 columns.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with a 'filename' and a 'years' column; each image is\n",
    "            opened from img/ABO/<years>/no_category/<filename>.\n",
    "        num_images: Maximum number of images to draw.\n",
    "\n",
    "    Each subplot is titled with the 0-based position of its row in `df`.\n",
    "    The grid is sized to the number of images actually drawn, so a short\n",
    "    `df` no longer produces a mostly empty figure.\n",
    "    \"\"\"\n",
    "    shown = df[['filename', 'years']].head(max(num_images, 0))\n",
    "    if shown.empty:\n",
    "        # Nothing to draw; also avoids building a grid with zero rows.\n",
    "        return\n",
    "    cols = 8\n",
    "    rows = int(np.ceil(len(shown)/cols))\n",
    "    fig = plt.figure(figsize=(int(cols*2.5), int(rows*2.5)))\n",
    "\n",
    "    for i, (filename, year) in enumerate(shown.values):\n",
    "        ax = fig.add_subplot(rows, cols, i + 1)\n",
    "        ax.axis('off')\n",
    "        ax.set_title(i)\n",
    "        # Close each file after drawing so a large grid does not leak handles.\n",
    "        with Image.open(f'img/ABO/{year}/no_category/{filename}') as image:\n",
    "            ax.imshow(image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bff4b449-2875-4563-a6d5-fe1c349c87f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One prediction CSV per year range, listed in chronological order.\n",
    "prediction_files = [\n",
    "    'data/predictions/pred_1501_1549.csv',\n",
    "    'data/predictions/pred_1550_1599.csv',\n",
    "    'data/predictions/pred_1600_1650.csv',\n",
    "    'data/predictions/pred_1651_1699.csv',\n",
    "    'data/predictions/pred_1700_1738.csv',\n",
    "]\n",
    "(pred_1501_1549, pred_1550_1599, pred_1600_1650,\n",
    " pred_1651_1699, pred_1700_1738) = map(pd.read_csv, prediction_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "923560af-a53f-4ae0-b37c-17793d298d4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag every frame with the year range it covers.\n",
    "frames_by_years = {\n",
    "    '1501_1549': pred_1501_1549,\n",
    "    '1550_1599': pred_1550_1599,\n",
    "    '1600_1650': pred_1600_1650,\n",
    "    '1651_1699': pred_1651_1699,\n",
    "    '1700_1738': pred_1700_1738,\n",
    "}\n",
    "for years_label, frame in frames_by_years.items():\n",
    "    frame['years'] = years_label\n",
    "\n",
    "# Stack the frames (row labels are kept as they are) and persist the result.\n",
    "all_frames = list(frames_by_years.values())\n",
    "all_pred = pd.concat(all_frames)\n",
    "all_pred.to_csv('data/predictions/all_pred_1501_1738.csv')\n",
    "print(len(all_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25bd6f96-6386-461a-9e6d-f86b12ab824b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the JSON-encoded probability lists. Values that are already\n",
    "# decoded are left untouched, so re-running this cell does not raise.\n",
    "all_pred['probability'] = all_pred['probability'].apply(\n",
    "    lambda x: json.loads(x) if isinstance(x, str) else x\n",
    ")\n",
    "# Position k of each list becomes its own column, in this order.\n",
    "for position, column in enumerate(['p_A', 'p_B', 'p_C', 'p_N']):\n",
    "    all_pred[column] = all_pred['probability'].apply(lambda x, k=position: x[k])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a96c3309-0f50-4d00-b321-e87245d49ef5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_GT_addition(GT_df, candidate_df, num_additions):\n",
    "    \"\"\"Build new ground-truth rows from prediction candidates.\n",
    "\n",
    "    Walks `candidate_df` in order and creates one row for every candidate\n",
    "    whose barcode does not occur (as a substring) in `GT_df['Strichcode']`,\n",
    "    stopping once `num_additions` rows have been collected.\n",
    "\n",
    "    Args:\n",
    "        GT_df: Existing ground truth with the columns 'Unnamed: 0' (running\n",
    "            index) and 'Strichcode' (barcode).\n",
    "        candidate_df: Candidates whose first column is a file name of the\n",
    "            form '<barcode>_<page>.jpg' and whose second column is the value\n",
    "            stored as 'Variante'.\n",
    "        num_additions: Maximum number of rows to create.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one row per accepted candidate; 'Unnamed: 0'\n",
    "        continues the numbering of `GT_df`.\n",
    "    \"\"\"\n",
    "    return_lis = []\n",
    "    # Continue after the last existing index; start at 0 for an empty GT.\n",
    "    running_index = GT_df['Unnamed: 0'].iloc[-1] if len(GT_df) else -1\n",
    "    known_barcodes = GT_df['Strichcode']\n",
    "    for i, row in enumerate(candidate_df.itertuples(index=False, name=None)):\n",
    "        if len(return_lis) == num_additions:\n",
    "            print('Finishing at index', i)\n",
    "            break\n",
    "        bc = row[0].split('_')[0]\n",
    "        # regex=False: the barcode is literal text, so characters such as\n",
    "        # '+' or '.' must not be interpreted as a regular expression.\n",
    "        if known_barcodes.str.contains(bc, regex=False).any():\n",
    "            continue\n",
    "        page = row[0].split('_')[1].replace('.jpg', '')\n",
    "        running_index += 1\n",
    "        return_lis.append({\n",
    "            'Unnamed: 0': running_index,\n",
    "            'Strichcode': bc,\n",
    "            'Link': '',\n",
    "            'Variante': row[1],\n",
    "            'Farbe': '',\n",
    "            'Erhaltungsgrad': '',\n",
    "            'Bem.': '',\n",
    "            'Seite': int(page),\n",
    "            'Image URLs': [f'https://iiif.onb.ac.at/images/ABO/{bc}/{page}/full/full/0/native.jpg']\n",
    "        })\n",
    "    return pd.DataFrame(return_lis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "830fe823-75c0-4e95-9dc5-55a7fedede28",
   "metadata": {},
   "outputs": [],
   "source": [
    "GT_df = pd.read_csv('data/groundtruth/BE_GT_v2.csv')\n",
    "\n",
    "# Candidates are the predictions with a probability above 0.95 for A or C.\n",
    "confident_A = all_pred[all_pred['p_A'] > 0.95]\n",
    "confident_C = all_pred[all_pred['p_C'] > 0.95]\n",
    "GT_addition_A = create_GT_addition(GT_df, confident_A, 79)\n",
    "GT_addition_C = create_GT_addition(GT_df, confident_C, 150)\n",
    "\n",
    "# Append both additions to the old ground truth and drop the stale index column.\n",
    "new_GT = pd.concat(\n",
    "    [GT_df, GT_addition_A, GT_addition_C],\n",
    "    ignore_index=True,\n",
    ").drop('Unnamed: 0', axis=1)\n",
    "new_GT.to_csv('data/groundtruth/BE_GT_v3.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94ccf8d2-dd0a-469e-9f94-349ab546deda",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the predictions with a probability above 0.95 for category A.\n",
    "is_confident_A = all_pred['p_A'] > 0.95\n",
    "display_image_from_series(all_pred[is_confident_A], num_images=85)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18c4f053-3b42-4ee7-97d4-91e8e71570b4",
   "metadata": {},
   "source": [
    "## Download the new ground truth (GT) as square cutouts at reduced resolution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3cad43ef-7c46-4949-ae60-5afd06a7bb80",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "acf5edc590724e54a41c72adbcc97ad5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/804 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "img_dir = pathlib.Path('img')\n",
    "GT_path = img_dir.joinpath('GT_square')\n",
    "# parents=True also creates img/; exist_ok makes a re-run a no-op.\n",
    "GT_path.mkdir(parents=True, exist_ok=True)\n",
    "GT_csv = pd.read_csv('data/groundtruth/BE_GT_v3.csv')\n",
    "failed_urls = []  # URLs that could not be downloaded or decoded\n",
    "for category, URL_lis, barcode in zip(tqdm(GT_csv['Variante']), GT_csv['Image URLs'], GT_csv['Strichcode']):\n",
    "    # The CSV stores each URL list as its Python repr. literal_eval parses\n",
    "    # that literal without executing code, unlike eval().\n",
    "    url_lis = ast.literal_eval(URL_lis)\n",
    "    cat = get_category(category)\n",
    "    GT_cat_path = GT_path.joinpath(cat)\n",
    "    GT_cat_path.mkdir(exist_ok=True)\n",
    "    for url in url_lis:\n",
    "        page_number = re.findall('/(.{8,12})/full/full', url)[0]\n",
    "        filename = f'{barcode}_{page_number}.jpg'\n",
    "        if 'REPO' not in url:\n",
    "            filepath = GT_cat_path.joinpath(filename)\n",
    "            if filepath.exists():\n",
    "                continue\n",
    "            # Request the 256px square cutout from the image server directly.\n",
    "            url = url.replace('full/full', 'square/256,')\n",
    "            img_content = get_content(url)\n",
    "            if not img_content:\n",
    "                # get_content gives None once its retries are used up.\n",
    "                failed_urls.append(url)\n",
    "                continue\n",
    "            # write_bytes closes the file, unlike a bare open().write().\n",
    "            filepath.write_bytes(img_content)\n",
    "        else:\n",
    "            filepath = GT_cat_path.joinpath(filename.replace('.jpg.', '.'))\n",
    "            if filepath.exists():\n",
    "                continue\n",
    "            # Use the shared helper so this branch gets the same status check.\n",
    "            img_content = get_content(url)\n",
    "            img = None\n",
    "            if img_content:\n",
    "                img = cv2.imdecode(np.frombuffer(img_content, dtype='uint8'), cv2.IMREAD_COLOR)\n",
    "            if img is None:\n",
    "                # Download failed or the bytes are not a decodable image.\n",
    "                failed_urls.append(url)\n",
    "                continue\n",
    "            img_height, img_width = img.shape[:2]\n",
    "            # Centre-crop to the shorter side. For portrait images this is the\n",
    "            # same crop as before; for landscape images the old vertical\n",
    "            # offset went negative and produced a wrong slice.\n",
    "            side = min(img_height, img_width)\n",
    "            x = (img_width - side) // 2\n",
    "            y = (img_height - side) // 2\n",
    "            square_img_content = img[y:y+side, x:x+side]\n",
    "            resized_img = cv2.resize(square_img_content, (256, 256))\n",
    "            cv2.imwrite(filepath.as_posix(), resized_img)\n",
    "\n",
    "if failed_urls:\n",
    "    print(f'{len(failed_urls)} image(s) could not be fetched:')\n",
    "    for failed_url in failed_urls:\n",
    "        print(failed_url)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd1ab4b1-e938-4337-856b-f04ae7b367ac",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}