diff --git a/README.md b/README.md index 6c83fd3..0f91661 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,19 @@ To evaluate a model, use the following template: python validate.py --arch=CLIP:ViT-L/14 --ckpt=path/to/the/saved/mode/checkpoint/model_epoch_best.pth --result_folder=path/to/save/the/results --fully_supervised ``` +To output the images as well e.g.: +```bash +python validate.py --arch=CLIP:ViT-L/14 --ckpt=models/vit-20.pth --result_folder=results --fully_supervised --output_save_path=output-images +``` + +To generate the grids of images for a dataset e.g. lama.: +```bash +python generate_image_grids.py lama +``` + +Image grids will be save by default to `./output-grids/[dataset]` + + ## License

The code is licensed under CC BY-NC-SA 4.0

diff --git a/generate_image_grids.py b/generate_image_grids.py new file mode 100644 index 0000000..cf6db73 --- /dev/null +++ b/generate_image_grids.py @@ -0,0 +1,113 @@ +import os +import argparse +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +import matplotlib.gridspec as gridspec + + +def plot_image_rows(real_image_path, fake_images, mask_images, output_images, output_path): + """ + Plots three rows of images for each variation with larger column headings and saves to output_path. + + Args: + real_image_path (str): Path to the real image. + fake_images (list of str): Paths to the fake images (3 variations). + mask_images (list of str): Paths to the mask images (3 variations). + output_images (list of str): Paths to the output images (3 variations). + output_path (str): Path to save the rows of images. + """ + num_rows = len(fake_images) # Should be 3 + num_cols = 4 # Real, Fake (Inpainted), Mask, Output (DeCLIP) + column_titles = ["Real", "Inpainted", "Mask", "DeCLIP"] + + # Set up the figure with a tighter layout + fig = plt.figure(figsize=(num_cols * 4, (num_rows + 0.5) * 4)) + spec = gridspec.GridSpec(num_rows + 1, num_cols, figure=fig, height_ratios=[0.2] + [1] * num_rows) + + # Add column headings with larger font size + for i, title in enumerate(column_titles): + ax = fig.add_subplot(spec[0, i]) + ax.text(0.5, 0.5, title, fontsize=24, ha="center", va="center") # Increased font size + ax.axis("off") + + # Adjust spacing to make headings closer to images + plt.subplots_adjust(top=0.9, hspace=0.1) + + # Add images in subsequent rows + for i in range(num_rows): + images = [ + mpimg.imread(real_image_path), + mpimg.imread(fake_images[i]), + mpimg.imread(mask_images[i]), + mpimg.imread(output_images[i]), + ] + + for j, img in enumerate(images): + ax = fig.add_subplot(spec[i + 1, j]) + ax.imshow(img) + ax.axis("off") + + # Save the plot as an image + plt.tight_layout(pad=0.5) + plt.savefig(output_path) + plt.close() + + +def process_dataset(dataset, real_dir, fake_dir_template, mask_dir_template, output_dir_template, output_rows_dir): + """ + Processes the dataset, creating rows of images. + + Args: + dataset (str): Dataset name. + real_dir (str): Directory of real images. + fake_dir_template (str): Template for fake image directory (contains [dataset]). + mask_dir_template (str): Template for mask directory (contains [dataset]). + output_dir_template (str): Template for output image directory (contains [dataset]). + output_rows_dir (str): Directory to save the output rows of images. + """ + # Update paths with the dataset name + fake_dir = fake_dir_template.replace("[dataset]", dataset) + mask_dir = mask_dir_template.replace("[dataset]", dataset) + output_dir = output_dir_template.replace("[dataset]", dataset) + output_rows_dataset_dir = os.path.join(output_rows_dir, dataset) + + # Create output directory if it doesn't exist + os.makedirs(output_rows_dataset_dir, exist_ok=True) + + # Process each real image + for real_image_name in sorted(os.listdir(real_dir)): + if real_image_name.endswith(".png"): + real_image_path = os.path.join(real_dir, real_image_name) + + # Generate fake, mask, and output image paths + base_name = os.path.splitext(real_image_name)[0] + fake_images = [os.path.join(fake_dir, f"{base_name}-{i}.png") for i in range(3)] + mask_images = [os.path.join(mask_dir, f"{base_name}-{i}.png") for i in range(3)] + output_images = [os.path.join(output_dir, f"{base_name}-{i}.png") for i in range(3)] + + # Ensure all files exist + if all(os.path.exists(p) for p in fake_images + mask_images + output_images): + # Save the rows of images + output_row_path = os.path.join(output_rows_dataset_dir, f"{base_name}_grid.png") + plot_image_rows(real_image_path, fake_images, mask_images, output_images, output_row_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate rows of images for comparison.") + parser.add_argument("dataset", type=str, help="Dataset name to process.") + parser.add_argument("--real_dir", type=str, default="./datasets/dolos_data/celebahq/real/test", + help="Path to the directory of real images.") + parser.add_argument("--fake_dir_template", type=str, + default="./datasets/dolos_data/celebahq/fake/[dataset]/images/test", + help="Template path for fake image directory, containing [dataset].") + parser.add_argument("--mask_dir_template", type=str, + default="./datasets/dolos_data/celebahq/fake/[dataset]/masks/test", + help="Template path for mask directory, containing [dataset].") + parser.add_argument("--output_dir_template", type=str, + default="./output-images/[dataset]", + help="Template path for output image directory, containing [dataset].") + parser.add_argument("--output_rows_dir", type=str, default="output-grids", + help="Directory to save the output rows of images.") + + args = parser.parse_args() + process_dataset(args.dataset, args.real_dir, args.fake_dir_template, args.mask_dir_template, args.output_dir_template, args.output_rows_dir)