from .colored_postcards.download_color import *
from .util import *
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
# from torchvision import transforms, utils


class ColDataset(Dataset):
    """Colored Postcards dataset.

    Yields one sample per colored postcard: the image at
    ``<root_dir>/<id>.jpg`` converted to CIE Lab and split into an ``'L'``
    tensor of shape (1, H, W) and an ``'ab'`` tensor of shape (2, H, W).
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        :param csv_file (string): Path to the csv file with image names
        :param root_dir (string): Directory with all the images
        :param transform (callable, optional): Optional transform to be applied on a sample
        """
        # NOTE(review): assumes get_raw_data returns a DataFrame whose first
        # column is the image id and which has a 'color' flag column.
        df = get_raw_data(csv_file)
        self.id_frame = self._color_rows(df)
        self.root_dir = root_dir
        self.transform = transform

    @staticmethod
    def _color_rows(df):
        """Return the rows of *df* whose 'color' flag is True.

        The flag is compared through its string form so the filter works
        whether the column holds the string 'True' or a boolean True.
        """
        return df[df['color'].astype(str) == 'True']

    @staticmethod
    def _image_path(root_dir, img_id):
        """Build the path ``<root_dir>/<img_id>.jpg``."""
        # str() so ids that pandas parsed as numbers still form a valid path.
        return os.path.join(root_dir, str(img_id)) + '.jpg'

    @staticmethod
    def _lab_to_tensors(image_lab):
        """Split an (H, W, 3) Lab array into L (1, H, W) and ab (2, H, W) float tensors."""
        tens_l = torch.Tensor(image_lab[:, :, 0])[None, :, :]
        # (H, W, 2) -> (2, H, W): channels first, as PyTorch models expect.
        tens_ab = torch.Tensor(image_lab[:, :, 1:]).permute(2, 0, 1)
        return tens_l, tens_ab

    def __len__(self):
        """Return the number of colored postcards in the dataset."""
        return len(self.id_frame)

    def __getitem__(self, idx):
        """Load sample *idx* as ``{'L': tensor, 'ab': tensor}``.

        :param idx (int or 0-d tensor): Position of the image in the filtered frame
        :return: the sample dict, passed through ``transform`` when one was given
        :raises TypeError: if *idx* is a sequence of indices (batch indexing
            is not supported)
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if isinstance(idx, (list, tuple)):
            raise TypeError('ColDataset only supports a single integer index')

        img_id = self.id_frame.iloc[idx, 0]
        img_path = self._image_path(self.root_dir, img_id)
        image = load_img(img_path)
        image_lab = color.rgb2lab(image)
        tens_l, tens_ab = self._lab_to_tensors(image_lab)
        sample = {'L': tens_l, 'ab': tens_ab}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

