import argparse
import torch
from models import *
import matplotlib.pyplot as plt
import torch.optim as optim

from torch.utils.data import random_split
from torch.utils.data import DataLoader
import time
import copy
import sys

# Command-line options of the script.
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--img_path', type=str, default='imgs/AK044_271.jpg')
parser.add_argument('--use_gpu', action='store_true', help='whether to use GPU')
parser.add_argument('-o', '--save_prefix', type=str, default='imgs_out/AK044_271', help='will save into this file with {eccv16.png, siggraph17.png} suffixes')
# parser.add_argument('-m', '--model', type=str, default='siggraph')
opt = parser.parse_args()
# Honor --use_gpu: CUDA is selected only when it was requested AND is actually available;
# in every other case the computation stays on the CPU.
device = torch.device("cuda:0" if opt.use_gpu and torch.cuda.is_available() else "cpu")


def train_model(model, criterion, optimizer, scheduler, num_epochs=25, *,
                loaders=None, run_device=None):
    """Train ``model`` and return it loaded with the weights of its best test epoch.

    Every epoch runs a 'train' phase (gradients enabled, optimizer stepped after
    each batch, scheduler stepped once at the end of the phase) followed by a
    'test' phase (gradients disabled). For each phase the batch losses are
    weighted by batch size and divided by the number of samples seen. The
    weights that produced the lowest such test loss are loaded back into the
    model before it is returned.

    Args:
        model: network to optimize, called as ``model(inputs)``. It is not moved
            to any device here, so it must already live where the batches go.
        criterion: callable ``criterion(outputs, labels)`` returning a
            single-element loss tensor.
        optimizer: optimizer updating the model parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of epochs to run.
        loaders: mapping with ``'train'`` and ``'test'`` iterables of batches;
            each batch is indexed with ``'L'`` (inputs) and ``'ab'`` (labels).
            Defaults to the module-level ``dataloaders``.
        run_device: device the batches are moved to. Defaults to the
            module-level ``device``.

    Returns:
        The same ``model`` object, with the best weights loaded.
    """
    if loaders is None:
        loaders = dataloaders
    if run_device is None:
        run_device = device

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    # Start from +inf so the first test epoch always becomes the baseline,
    # regardless of the scale of the loss.
    best_loss = float('inf')

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and a test phase
        for phase in ['train', 'test']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            samples_seen = 0

            for i, data in enumerate(loaders[phase]):
                inputs = data['L'].to(run_device)
                labels = data['ab'].to(run_device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in the training phase
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics: weight the batch loss by the batch size so that a
                # smaller final batch does not skew the epoch average
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                samples_seen += batch_size
                sys.stdout.write('\r' 'Id ' + str(i) + ', Running loss: ' + str(running_loss))
                sys.stdout.flush()

            if samples_seen:
                # terminate the '\r' progress line so the summary starts on its own line
                sys.stdout.write('\n')

            if is_train:
                scheduler.step()

            # Divide by the number of samples actually processed (not the number
            # of batches); an empty loader yields NaN instead of a crash.
            epoch_loss = running_loss / samples_seen if samples_seen else float('nan')

            print('{} Loss: {:.4f}'.format(
                phase, epoch_loss))

            # deep copy the model whenever the test loss improves
            if phase == 'test' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Loss: {:.4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


# load colorizer and move it to the training device before the optimizer is
# built, so the model and the batches sent to `device` live in the same place
colorizer = siggraph17()
colorizer.to(device)

for name, _ in colorizer.named_parameters():
    print(name)

postcard_dataset = ColDataset(csv_file='models/colored_postcards/akon_postcards_public_domain.csv',
                              root_dir='models/colored_postcards/low_res_imgs')
# 70/30 split. The test size is the remainder: rounding both fractions
# independently can produce lengths that do not sum to len(dataset), which
# random_split rejects.
n_total = len(postcard_dataset)
n_train = round(0.7 * n_total)
(train_set, test_set) = random_split(postcard_dataset, [n_train, n_total - n_train])
train_loader = DataLoader(train_set, shuffle=True, batch_size=10)  # batch_size = 5,10,...
test_loader = DataLoader(test_set, batch_size=10)
dataloaders = {'train': train_loader, 'test': test_loader}
# Number of samples (not batches) per split: the running loss is accumulated
# per sample, so this is the matching denominator for the epoch average.
dataset_sizes = {'train': len(train_set), 'test': len(test_set)}
criterion = nn.SmoothL1Loss()
optimizer = optim.Adam(colorizer.parameters(), lr=0.001)
# Multiply the learning rate by 0.1 every 7 scheduler steps
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
#
# colorizer = train_model(colorizer, criterion, optimizer, scheduler)
# PATH = './models/col_test.pth'
# torch.save(colorizer.state_dict(), PATH)
#
#
# img = load_img('models/colored_postcards/imgs/AK001_011.jpg')
# low_img = load_img('models/colored_postcards/low_res_imgs/AK001_011.jpg')
# (tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256, 256))
# (t_l_orig, t_l_rs) = preprocess_img(low_img, HW=(256, 256))
# (tens_ab_orig, tens_ab_rs) = get_ab_channel(img, hw=(256, 256))
#
# colorizer_siggraph17.train()
# target = tens_ab_rs
# criterion = nn.SmoothL1Loss()
# optimizer = optim.Adam(colorizer_siggraph17.parameters(), lr=0.01)
#
# # in the training loop:
# optimizer.zero_grad()
# output = colorizer_siggraph17(t_l_rs)
# loss = criterion(output, target)
# print(loss)
# loss.backward()
# optimizer.step()
#
# img_bw = postprocess_tens(tens_l_orig, torch.cat((0*tens_l_orig, 0*tens_l_orig), dim=1))
# out_img_siggraph17 = postprocess_tens(tens_l_orig, output)
# # plt.imsave('test_bw.jpg', img_bw)
# plt.imsave('test_output.jpg', out_img_siggraph17)
