#!/usr/bin/env python
from __future__ import (absolute_import, division, print_function)
from PIL import Image, ImageFile
import sys
import os
import shutil
import errno
from imagehash import average_hash, phash, dhash, whash
import collections
import multiprocessing
import time
import math
import argparse
from argparse import RawTextHelpFormatter
ImageFile.LOAD_TRUNCATED_IMAGES = True


def ensure_folder(path):
    """Create ``path`` together with any missing parent directories.

    An ``OSError`` carrying ``errno.EEXIST`` is ignored; every other
    ``OSError`` raised by ``os.makedirs`` is propagated unchanged.
    """
    try:
        os.makedirs(path)
    except OSError as err:
        already_present = err.errno == errno.EEXIST
        if not already_present:
            raise


def image_ahash(image_path):
    """Compute the average hash of the image file at ``image_path``.

    Returns a ``(image_path, hash)`` tuple. The hash slot is ``None`` when
    opening or hashing the file raises any ``Exception``.
    """
    try:
        with Image.open(image_path) as img:
            digest = average_hash(img)
    except Exception:
        # Best effort: a file that cannot be hashed is reported with None.
        return (image_path, None)
    else:
        return (image_path, digest)

def image_phash(image_path):
    """Compute the perceptual hash (``phash``) of the image at ``image_path``.

    Returns a ``(image_path, hash)`` tuple. The hash slot is ``None`` when
    opening or hashing the file raises any ``Exception``.
    """
    try:
        with Image.open(image_path) as img:
            digest = phash(img)
    except Exception:
        # Best effort: a file that cannot be hashed is reported with None.
        return (image_path, None)
    else:
        return (image_path, digest)

def image_dhash(image_path):
    """Compute the difference hash (``dhash``) of the image at ``image_path``.

    Returns a ``(image_path, hash)`` tuple. The hash slot is ``None`` when
    opening or hashing the file raises any ``Exception``.
    """
    try:
        with Image.open(image_path) as img:
            digest = dhash(img)
    except Exception:
        # Best effort: a file that cannot be hashed is reported with None.
        return (image_path, None)
    else:
        return (image_path, digest)

def image_whash_haar(image_path):
    """Compute the wavelet hash of the image at ``image_path``.

    ``whash`` is called with its default mode.

    Returns a ``(image_path, hash)`` tuple. The hash slot is ``None`` when
    opening or hashing the file raises any ``Exception``.
    """
    try:
        with Image.open(image_path) as img:
            digest = whash(img)
    except Exception:
        # Best effort: a file that cannot be hashed is reported with None.
        return (image_path, None)
    else:
        return (image_path, digest)

def image_whash_db4(image_path):
    """Compute the wavelet hash of the image at ``image_path`` in ``'db4'`` mode.

    Returns a ``(image_path, hash)`` tuple. The hash slot is ``None`` when
    opening or hashing the file raises any ``Exception``.
    """
    try:
        with Image.open(image_path) as img:
            digest = whash(img, mode='db4')
    except Exception:
        # Best effort: a file that cannot be hashed is reported with None.
        return (image_path, None)
    else:
        return (image_path, digest)

def is_image(filename):
    """Return True when ``filename`` ends with a supported image extension.

    The comparison is case-insensitive and covers png, jpg, jpeg, bmp and gif.
    """
    supported_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".gif")
    return filename.lower().endswith(supported_suffixes)

def _mirror_destination(output_root, source_dir):
    """Return the folder below ``output_root`` that mirrors ``source_dir``.

    The drive (or UNC share) prefix and all leading path separators of
    ``source_dir`` are removed first. Otherwise ``os.path.join`` would see a
    rooted second argument and silently discard ``output_root``.
    """
    tail = os.path.splitdrive(source_dir)[1]
    separators = os.sep + (os.altsep or "")
    return os.path.join(output_root, tail.lstrip(separators))


if __name__ == '__main__':
    # Command line entry point: hash every image below the input folder and
    # move all images that share a hash value into a mirrored tree below the
    # (not yet existing) output folder.
    parser = argparse.ArgumentParser(
        description='Find similar images in the input folder and move them to the output folder',
        formatter_class=RawTextHelpFormatter)

    # Maps every accepted -hm value to the worker function computing that hash.
    hash_dict = {
        'ahash': image_ahash,
        'phash': image_phash,
        'dhash': image_dhash,
        'whash-haar': image_whash_haar,
        'whash-db4': image_whash_db4,
    }

    method_help = """Method:
    ahash:  Average hash
    phash:  Perceptual hash
    dhash:  Difference hash
    whash-haar: Haar wavelet hash
    whash-db4:  Daubechies wavelet hash"""

    parser.add_argument("-hm", type=str,
                        choices=list(hash_dict),
                        default="ahash",
                        help=method_help)
    parser.add_argument('input_folder', type=str,
                        help='input folder')
    parser.add_argument('output_folder', type=str,
                        help='output folder')
    parser.add_argument("-now", type=int,
                        help='number of subproccesses for hashing', default=None)

    args = parser.parse_args()

    folder_in = os.path.abspath(args.input_folder)
    folder_out = os.path.abspath(args.output_folder)
    image_hashfunc = hash_dict[args.hm]
    poolsize = args.now  # None lets multiprocessing pick the worker count

    if os.path.exists(folder_out):
        print("output folder exists!")
        print("Delete the folder before proceding")
        # sys.exit instead of the site-provided exit(), which is not
        # guaranteed to exist (e.g. when running with ``python -S``).
        sys.exit(1)

    print("Searching input folder:", folder_in)
    image_filenames = tuple(
        os.path.join(root, filename)
        for root, _dirs, filenames in os.walk(folder_in)
        for filename in filenames
        if is_image(filename)
    )

    total = len(image_filenames)
    print("Number of found images:", total)

    print("Hashing now...")
    # hash value -> list of image paths that produced it
    images = collections.defaultdict(list)
    pool = multiprocessing.Pool(poolsize)
    results = pool.imap_unordered(image_hashfunc, image_filenames, chunksize=100)
    pool.close()

    bar_length = 50
    last_shown = None
    for n, (image_path, digest) in enumerate(results):
        images[digest].append(image_path)
        percents = round(100.0 * n / float(total), 1)
        # Redraw the bar only when the displayed percentage changes.
        if percents != last_shown:
            filled_len = int(round(bar_length * n / total))
            bar = '=' * filled_len + '-' * (bar_length - filled_len)
            sys.stdout.write('[%s] %s\r' % (bar, str(percents) + "%"))
            sys.stdout.flush()
        last_shown = percents
    # All results are consumed; wait for the workers to shut down.
    pool.join()
    # Padded so it fully overwrites a longer previous value such as "99.9%".
    sys.stdout.write('[%s] %-6s' % ('=' * bar_length, "100%"))
    sys.stdout.flush()

    # Files that could not be hashed were collected under the None key.
    # pop() removes that bucket without creating an empty one as a side effect.
    problematic = images.pop(None, [])
    if problematic:
        print("")
        print("Problematic files:")
        for image in problematic:
            print(image)

    # Only hash values shared by at least two files count as "similar".
    similar_files = sum(len(group) for group in images.values() if len(group) > 1)

    print("")
    if similar_files:
        print("Number of similar files:", similar_files)
        print("Moving to output folder")
        for img_list in images.values():
            if len(img_list) > 1:
                for img in img_list:
                    to_folder = _mirror_destination(folder_out, os.path.dirname(img))
                    ensure_folder(to_folder)
                    shutil.move(img, to_folder)
    else:
        print("No similar pictures found")
    print("All Done!")

# Published
#
# Category
#
# snippets
#
# Tags