import argparse
import os

import matplotlib.pyplot as plt
import torch

from models import *

def _str2bool(value):
	"""Convert a command-line string to a bool.

	argparse's ``type=bool`` treats every non-empty string (including
	"False") as True, so the common spellings are parsed explicitly and
	anything else is rejected with a usage error.
	"""
	if isinstance(value, bool):
		return value
	lowered = value.strip().lower()
	if lowered in ('true', 't', 'yes', 'y', '1'):
		return True
	if lowered in ('false', 'f', 'no', 'n', '0'):
		return False
	raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line options for the colorization demo.
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--img_path', type=str, default='imgs/AK044_271-271.jpg')
parser.add_argument('--use_gpu', action='store_true', help='whether to use GPU')
parser.add_argument('-o', '--output', type=str, default='imgs_out/AK044_271-271')
parser.add_argument('-m', '--model', type=str, default='models/siggraph17-df00044c.pth', help='path to siggraph17 model weights file')
# nargs='?' + const=True: a bare "-l" enables legacy mode, while the old
# "-l True" / "-l False" forms keep working (and "-l False" now really is False).
parser.add_argument('-l', '--legacy', type=_str2bool, nargs='?', const=True, default=False, help='whether to use legacy ECCV model')
opt = parser.parse_args()

# Load the colorizer: the legacy ECCV16 network, or SIGGRAPH17 using the
# weights file passed via --model. eval() puts the network in inference mode.
if opt.legacy:
	colorizer = eccv16(pretrained=True).eval()
else:
	colorizer = siggraph17(pretrained=True, model_path=opt.model).eval()
# Honor --use_gpu (previously parsed but ignored).
if opt.use_gpu:
	colorizer = colorizer.cuda()

# Grab the L channel at the original resolution ("orig") and resized to
# 256x256 ("rs"), the size fed to the network here.
img = load_img(opt.img_path)
(tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256, 256))
if opt.use_gpu:
	tens_l_rs = tens_l_rs.cuda()

# Grayscale reference image: original L channel with two all-zero ab channels.
zero_ab = torch.zeros_like(tens_l_orig)
img_bw = postprocess_tens(tens_l_orig, torch.cat((zero_ab, zero_ab), dim=1))

# Pure inference: no_grad avoids building an autograd graph. The ab output is
# brought back to the CPU; postprocess_tens is relied on to resize it to the
# original resolution and merge it with the original L channel.
with torch.no_grad():
	out_ab = colorizer(tens_l_rs).cpu()
out_img = postprocess_tens(tens_l_orig, out_ab)

# The output path is a file prefix; create its directory so imsave can write.
out_dir = os.path.dirname(opt.output)
if out_dir:
	os.makedirs(out_dir, exist_ok=True)
plt.imsave(f'{opt.output}_col.jpg', out_img)
plt.imsave(f'{opt.output}_bw.jpg', img_bw)
print('Done!')