diff --git a/colorization/demo_release.py b/colorization/demo_release.py
index 94b6e24e0626841e1d31f82355a76d77f5ad3fd5..bd4d12221f229fbda3a25cbe8efa6d388c94ed78 100644
--- a/colorization/demo_release.py
+++ b/colorization/demo_release.py
@@ -14,7 +14,7 @@ opt = parser.parse_args()
 use_siggraph = opt.model == 'siggraph'
 # load colorizers
 if use_siggraph:
-    colorizer_siggraph17 = siggraph17(pretrained=True).eval()
+    colorizer_siggraph17 = siggraph17(pretrained=True, model_path='models/col_test3.pth').eval()
 if not use_siggraph:
     colorizer_eccv16 = eccv16(pretrained=True).eval()
 # if(opt.use_gpu):
@@ -39,7 +39,7 @@ if use_siggraph:
 if not use_siggraph:
     plt.imsave('%s_eccv16.png'%opt.save_prefix, out_img_eccv16)
 if use_siggraph:
-    plt.imsave('%s_siggraph17.png'%opt.save_prefix, out_img_siggraph17)
+    plt.imsave('%s.jpg'%opt.save_prefix, out_img_siggraph17)
 
 # plt.figure(figsize=(12,8))
 # plt.subplot(2,2,1)
diff --git a/colorization/models/ColDataset.py b/colorization/models/ColDataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe37e812303a5de3361012378953fd0b53173a62
--- /dev/null
+++ b/colorization/models/ColDataset.py
@@ -0,0 +1,54 @@
+from .colored_postcards.download_color import *
+from .util import *
+import os
+import torch
+import pandas as pd
+from skimage import color
+import numpy as np
+import matplotlib.pyplot as plt
+from torch.utils.data import Dataset, DataLoader
+# from torchvision import transforms, utils
+
+
+class ColDataset(Dataset):
+    """Colored Postcards dataset"""
+
+    def __init__(self, csv_file, root_dir, transform=None):
+        """
+        :param csv_file (string): Path to the csv file with image names
+        :param root_dir (string): Directory with all the images
+        :param transform (callable, optional): Optional transform to be applied on a sample
+        """
+        df = get_raw_data(csv_file)
+        is_color = df['color'] == 'True'
+        df_color = df[is_color]
+        self.id_frame = df_color
+        self.root_dir = root_dir
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.id_frame)
+
+    def __getitem__(self, idx):
+        if torch.is_tensor(idx):
+            idx = idx.tolist()
+
+        img_id = self.id_frame.iloc[idx, 0]
+        img_path = os.path.join(self.root_dir, img_id) + '.jpg'
+        image = load_img(img_path)
+        image_lab = color.rgb2lab(image)
+        tens_l = torch.Tensor(image_lab[:, :, 0])[None, :, :]
+        tens_ab = torch.Tensor(image_lab[:, :, 1:]).transpose(2, 1).transpose(1, 0)
+        sample = {'L': tens_l, 'ab': tens_ab}
+
+        if self.transform:
+            sample = self.transform(sample)
+
+        return sample
+
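+# A minimal usage sketch (paths as in my_colorize.py; assumes the images were
+# fetched beforehand with colored_postcards/download_color.py):
+#     dataset = ColDataset(csv_file='models/colored_postcards/akon_postcards_public_domain.csv',
+#                          root_dir='models/colored_postcards/low_res_imgs')
+#     loader = DataLoader(dataset, shuffle=True)
+#     batch = next(iter(loader))  # batch['L']: 1 x 1 x H x W, batch['ab']: 1 x 2 x H x W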
diff --git a/colorization/models/__init__.py b/colorization/models/__init__.py
index 058dfb3b46c5c12872d358e89301739e49cdbf18..25d4ea1496d3e6fc330ede0ce42c82ca0d2c5477 100644
--- a/colorization/models/__init__.py
+++ b/colorization/models/__init__.py
@@ -1,6 +1,5 @@
-
 from .base_color import *
 from .eccv16 import *
 from .siggraph17 import *
 from .util import *
-
+from .ColDataset import *
diff --git a/colorization/models/colored_postcards/__init__.py b/colorization/models/colored_postcards/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/colorization/models/colored_postcards/download_color.py b/colorization/models/colored_postcards/download_color.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba515abba1e6295efe76c16e1ea52d04b7b76ccb
--- /dev/null
+++ b/colorization/models/colored_postcards/download_color.py
@@ -0,0 +1,53 @@
+import pandas as pd
+import requests
+import jsonpath_ng as jp
+import pathlib
+import re
+import sys
+
+
+def get_raw_data(filepath='akon_postcards_public_domain.csv'):
+    df = pd.read_csv(filepath_or_buffer=filepath, dtype=str)
+    out_df = df[['akon_id', 'color']]
+    return out_df
+
+
+jp_image_link = jp.parse('sequences[*].canvases[*].images[*].resource.@id')
+
+
+def get_image_link_for_akon_id(akon_id):
+    r = requests.get(f'https://iiif.onb.ac.at/presentation/AKON/{akon_id}/manifest')
+    r.raise_for_status()
+    links = [m.value for m in jp_image_link.find(r.json())]
+    if len(links) > 1:
+        print(f'{len(links)} images found for id {akon_id}.', file=sys.stderr)
+    return links[0]
+
+
+def download_and_save_image(akon_id, directory):
+    path = pathlib.Path(f'{directory}/{akon_id}.jpg')
+    if path.exists():
+        return
+    else:
+        r = requests.get(get_image_link_for_akon_id(akon_id))
+        path.write_bytes(r.content)
+
+
+def dl_save_low_res(akon_id, directory):
+    path = pathlib.Path(f'{directory}/{akon_id}.jpg')
+    if path.exists():
+        return
+    else:
+        link = get_image_link_for_akon_id(akon_id)
+        low_res_link = re.sub(r'full/full', 'full/256,256', link)
+        r = requests.get(low_res_link)
+        path.write_bytes(r.content)
+
+
+if __name__ == '__main__':
+    df = get_raw_data()
+    is_color = df['color'] == 'True'
+    df_color = df[is_color]
+    print(len(df_color))
+    # df_color.apply(lambda x: download_and_save_image(x[0], 'imgs'), axis=1)
+    df_color.apply(lambda x: dl_save_low_res(x[0], 'low_res_imgs'), axis=1)
diff --git a/colorization/models/siggraph17.py b/colorization/models/siggraph17.py
index 625a0744e70f8f4186fa5a148277cd71df51c1cf..8c015e54be2563685970d208622bfca4ebc1a178 100644
--- a/colorization/models/siggraph17.py
+++ b/colorization/models/siggraph17.py
@@ -161,10 +161,9 @@ class SIGGRAPHGenerator(BaseColor):
         return self.unnormalize_ab(out_reg)
 
 
-def siggraph17(pretrained=True):
+def siggraph17(pretrained=True, model_path='models/siggraph17-df00044c.pth'):
     model = SIGGRAPHGenerator()
     if(pretrained):
-        import torch.utils.model_zoo as model_zoo
-        model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',
-                                                 model_dir='./models', map_location='cpu', check_hash=True))
+        # import torch.utils.model_zoo as model_zoo
+        model.load_state_dict(torch.load(model_path, map_location='cpu'))
     return model
diff --git a/colorization/models/util.py b/colorization/models/util.py
index 79968ba6b960a8c10047f1ce52400b6bfe766b9c..33c04009dc6529a0522f4ab84b26037ba570d0f2 100644
--- a/colorization/models/util.py
+++ b/colorization/models/util.py
@@ -6,15 +6,18 @@ import torch
 import torch.nn.functional as F
 from IPython import embed
 
+
 def load_img(img_path):
     out_np = np.asarray(Image.open(img_path))
     if(out_np.ndim==2):
         out_np = np.tile(out_np[:,:,None],3)
     return out_np
 
+
 def resize_img(img, HW=(256,256), resample=3):
     return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample))
 
+
 def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):
     # return original size L and resized L as torch Tensors
     img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)
@@ -30,6 +33,7 @@
 
     return (tens_orig_l, tens_rs_l)
 
+
 def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
     # tens_orig_l 1 x 1 x H_orig x W_orig
     # out_ab 1 x 2 x H x W
@@ -39,9 +43,28 @@
 
     # call resize function if needed
     if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]):
-        out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode='bilinear')
+        out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode=mode, align_corners=False)
     else:
         out_ab_orig = out_ab
 
     out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
     return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0)))
+
+
+def get_ab_channel(img_rgb_orig, hw=(256, 256), resample=3):
+    img_rgb_rs = resize_img(img_rgb_orig, HW=hw, resample=resample)
+
+    img_lab_orig = color.rgb2lab(img_rgb_orig)
+    img_lab_rs = color.rgb2lab(img_rgb_rs)
+
+    img_ab_orig = img_lab_orig[:, :, 1:]
+    img_ab_rs = img_lab_rs[:, :, 1:]
+
+    tens_orig_ab = torch.Tensor(img_ab_orig)[None, :, :].transpose(3, 2).transpose(2, 1)
+    tens_rs_ab = torch.Tensor(img_ab_rs)[None, :, :].transpose(3, 2).transpose(2, 1)
+
+    return tens_orig_ab, tens_rs_ab
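+
+# Shape sketch: get_ab_channel mirrors preprocess_img but returns the ab channels,
+#     tens_orig_ab, tens_rs_ab = get_ab_channel(load_img('imgs/AK044_271.jpg'))
+# giving 1 x 2 x H_orig x W_orig and 1 x 2 x 256 x 256 tensors respectively.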
diff --git a/colorization/my_colorize.py b/colorization/my_colorize.py
new file mode 100644
index 0000000000000000000000000000000000000000..dad05dcb8b1cf80b419bb9bf948281671341651c
--- /dev/null
+++ b/colorization/my_colorize.py
@@ -0,0 +1,147 @@
+import argparse
+import torch
+import torch.nn as nn
+from models import *
+import matplotlib.pyplot as plt
+import torch.optim as optim
+
+from torch.utils.data import random_split
+from torch.utils.data import DataLoader
+import time
+import copy
+import sys
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-i', '--img_path', type=str, default='imgs/AK044_271.jpg')
+parser.add_argument('--use_gpu', action='store_true', help='whether to use GPU')
+parser.add_argument('-o', '--save_prefix', type=str, default='imgs_out/AK044_271', help='will save into this file with {eccv16.png, siggraph17.png} suffixes')
+# parser.add_argument('-m', '--model', type=str, default='siggraph')
+opt = parser.parse_args()
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
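+# train_model follows the usual PyTorch fine-tuning recipe: each epoch runs a
+# 'train' and a 'test' phase over the postcard loaders, accumulates per-sample
+# SmoothL1 loss on the predicted ab channels, and keeps the weights from the
+# epoch with the lowest test loss.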
+def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
+    since = time.time()
+
+    best_model_wts = copy.deepcopy(model.state_dict())
+    best_loss = 100
+
+    for epoch in range(num_epochs):
+        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
+        print('-' * 10)
+
+        # Each epoch has a training and validation phase
+        for phase in ['train', 'test']:
+            if phase == 'train':
+                model.train()  # Set model to training mode
+            else:
+                model.eval()  # Set model to evaluate mode
+
+            running_loss = 0.0
+            # running_corrects = 0
+
+            # Iterate over data.
+            for i, data in enumerate(dataloaders[phase]):
+                inputs = data['L'].to(device)
+                labels = data['ab'].to(device)
+
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                # forward
+                # track history if only in train
+                with torch.set_grad_enabled(phase == 'train'):
+                    outputs = model(inputs)
+                    # _, preds = torch.max(outputs, 1)
+                    loss = criterion(outputs, labels)
+
+                    # backward + optimize only if in training phase
+                    if phase == 'train':
+                        loss.backward()
+                        optimizer.step()
+
+                # statistics
+                running_loss += loss.item() * inputs.size(0)
+                sys.stdout.write('\r' 'Id ' + str(i) + ', Running loss: ' + str(running_loss))
+                # cap each phase after a few batches -- a debugging shortcut, not a full epoch
+                if ((i == 5) and (phase == 'test')) or (i == 20):
+                    break
+
+                # running_corrects += torch.sum(preds == labels.data)
+            if phase == 'train':
+                scheduler.step()
+
+            epoch_loss = running_loss / dataset_sizes[phase]
+            # epoch_acc = running_corrects.double() / dataset_sizes[phase]
+
+            print('{} Loss: {:.4f}'.format(
+                phase, epoch_loss))
+
+            # deep copy the model
+            if phase == 'test' and epoch_loss < best_loss:
+                best_loss = epoch_loss
+                best_model_wts = copy.deepcopy(model.state_dict())
+
+        print()
+
+    time_elapsed = time.time() - since
+    print('Training complete in {:.0f}m {:.0f}s'.format(
+        time_elapsed // 60, time_elapsed % 60))
+    print('Best val Loss: {:4f}'.format(best_loss))
+
+    # load best model weights
+    model.load_state_dict(best_model_wts)
+    return model
+
+
+# load colorizer
+colorizer = siggraph17(pretrained=False)
+postcard_dataset = ColDataset(csv_file='models/colored_postcards/akon_postcards_public_domain.csv',
+                              root_dir='models/colored_postcards/low_res_imgs')
+n_train = round(0.7 * len(postcard_dataset))  # explicit sizes so the split always sums to len(dataset)
+(train_set, test_set) = random_split(postcard_dataset, [n_train, len(postcard_dataset) - n_train])
+train_loader = DataLoader(train_set, shuffle=True)
+test_loader = DataLoader(test_set)
+dataloaders = {'train': train_loader, 'test': test_loader}
+dataset_sizes = {'train': len(train_set), 'test': len(test_set)}  # samples, not batches
+criterion = nn.SmoothL1Loss()
+optimizer = optim.Adam(colorizer.parameters(), lr=0.001)
+scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
+
+colorizer = train_model(colorizer, criterion, optimizer, scheduler)
+PATH = './models/col_test.pth'
+torch.save(colorizer.state_dict(), PATH)
+
+#
+# img = load_img('models/colored_postcards/imgs/AK001_011.jpg')
+# low_img = load_img('models/colored_postcards/low_res_imgs/AK001_011.jpg')
+# (tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256, 256))
+# (t_l_orig, t_l_rs) = preprocess_img(low_img, HW=(256, 256))
+# (tens_ab_orig, tens_ab_rs) = get_ab_channel(img, hw=(256, 256))
+#
+# colorizer_siggraph17.train()
+# target = tens_ab_rs
+# criterion = nn.SmoothL1Loss()
+# optimizer = optim.Adam(colorizer_siggraph17.parameters(), lr=0.01)
+#
+# # in the training loop:
+# optimizer.zero_grad()
+# output = colorizer_siggraph17(t_l_rs)
+# loss = criterion(output, target)
+# print(loss)
+# loss.backward()
+# optimizer.step()
+#
+# img_bw = postprocess_tens(tens_l_orig, torch.cat((0*tens_l_orig, 0*tens_l_orig), dim=1))
+# out_img_siggraph17 = postprocess_tens(tens_l_orig, output)
+# # plt.imsave('test_bw.jpg', img_bw)
+# plt.imsave('test_output.jpg', out_img_siggraph17)
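+
+# Note: demo_release.py now loads 'models/col_test3.pth' by default, while this
+# script saves to './models/col_test.pth'; point the demo at the new checkpoint,
+# e.g. siggraph17(pretrained=True, model_path='./models/col_test.pth'), or rename
+# the saved file to match.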
""" def __init__(self, manifest_url): + self.manifest_url = manifest_url self.manifest = requests.get(manifest_url).json() @property diff --git a/requirements.txt b/requirements.txt index 732cf12fa75a2765162e2f077d02e83e9a4cf243..53c2db9cd48457adf4d04db54076b0caeaad26d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,11 @@ opencv-python numpy -requests \ No newline at end of file +requests +torch +scikit-image +matplotlib +argparse +pandas +pillow +ipython +jsonpath-ng diff --git a/stitching/imgs_out/crop_stitch_all.jpg b/stitching/imgs_out/crop_stitch_all.jpg index 76dfe709145fdcf938e9fdd259f3f28a29c40dca..ca9f23cbf22853a5ac9e69bba7265daec5c5164d 100644 Binary files a/stitching/imgs_out/crop_stitch_all.jpg and b/stitching/imgs_out/crop_stitch_all.jpg differ