# File: gpu_colorization/Plots/CreatePlot_showImages.py
# Snapshot: 2022-03-30 17:01:33 +02:00 (84 lines, 2.6 KiB, Python)

import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
sys.path.append("..")
from Autoencoder import Autoencoder
from Training import prepare_data, getRGB
def main(
    data_dir="/home/timwitte/Downloads/",
    write_dir="../imagenet",
    weights_path="../saved_models/trainied_weights_epoch_12",
    num_imgs=5,
    output_path="ColoredImages.png",
):
    """Colorize one batch of ImageNet validation images and save a comparison grid.

    Loads ``imagenet2012`` through TFDS. It restores trained Autoencoder weights
    and runs the model on the first batch of the prepared validation split. It
    then writes a figure with one row per image: grayscale input, model output
    and ground truth.

    Args:
        data_dir: Directory passed to TFDS as ``manual_dir`` (where the
            manually obtained ImageNet archives are expected).
        write_dir: Root directory for TFDS extracted/downloaded/prepared data.
        weights_path: Checkpoint path passed to ``Autoencoder.load_weights``.
        num_imgs: Maximum number of image rows in the figure; clamped to the
            batch size.
        output_path: File the figure is saved to.

    Raises:
        ValueError: If ``num_imgs`` < 1 or the validation pipeline yields no batch.
    """
    if num_imgs < 1:
        raise ValueError(f"num_imgs must be >= 1, got {num_imgs}")

    # Keep all TFDS artefacts under write_dir; data_dir is only read from.
    download_config = tfds.download.DownloadConfig(
        extract_dir=os.path.join(write_dir, "extracted"),
        manual_dir=data_dir,
    )
    download_and_prepare_kwargs = {
        "download_dir": os.path.join(write_dir, "downloaded"),
        "download_config": download_config,
    }
    # shuffle_files=True: the plotted images may differ between runs.
    _train_dataset, test_dataset = tfds.load(
        "imagenet2012",
        data_dir=os.path.join(write_dir, "data"),
        split=["train", "validation"],
        shuffle_files=True,
        download=True,
        as_supervised=True,
        download_and_prepare_kwargs=download_and_prepare_kwargs,
    )
    test_dataset = test_dataset.take(32).apply(prepare_data)

    autoencoder = Autoencoder()
    # Build with a concrete input shape (a batch dimension is required) so the
    # variables exist before load_weights restores them.
    autoencoder.build((1, 256, 256, 1))
    autoencoder.load_weights(weights_path)
    autoencoder.summary()
    autoencoder.encoder.summary()
    autoencoder.decoder.summary()

    # Only the first batch is plotted. Fail with a clear message instead of
    # a NameError if the pipeline produced nothing.
    batch = next(iter(test_dataset.take(1)), None)
    if batch is None:
        raise ValueError("The validation pipeline produced no batches to plot.")
    img_L, img_AB_original = batch
    img_AB_reconstructed = autoencoder(img_L)
    img_rgb_original = getRGB(img_L, img_AB_original)
    img_rgb_reconstructed = getRGB(img_L, img_AB_reconstructed)

    # Assumes prepare_data yields batch-first tensors (TODO confirm); never
    # index past the end of the batch.
    num_rows = min(num_imgs, int(img_L.shape[0]))

    # squeeze=False keeps axs 2-D even for a single row. Fixing the size here
    # (5 inches per row, i.e. 15x25 for the default 5 rows) lets tight_layout
    # work with the final figure dimensions.
    fig, axs = plt.subplots(num_rows, 3, figsize=(15, 5 * num_rows), squeeze=False)
    for col, title in enumerate(("Input", "Output", "Ground Truth")):
        axs[0, col].set_title(title, fontsize=30)
    for i, row in enumerate(axs):
        row[0].imshow(img_L[i], cmap="gray")
        row[1].imshow(img_rgb_reconstructed[i])
        row[2].imshow(img_rgb_original[i])
        for ax in row:
            ax.set_axis_off()
    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)  # free the figure once it has been written
if __name__ == "__main__":
    # Script entry point. A Ctrl-C (e.g. during the long TFDS/model steps in
    # main) ends the run with a one-line notice instead of a traceback.
    interrupted = False
    try:
        main()
    except KeyboardInterrupt:
        interrupted = True
    if interrupted:
        print("KeyboardInterrupt received")