Files
immich_duplicate_finder-mirror/imageDuplicate.py
Deathproof76 3522f38b0f Fix convert RGBA to RGB image before Resnet Evaluation in imageDuplicate.py
Possible fix for https://github.com/vale46n1/immich_duplicate_finder/issues/32

ResNet does seem to have a problem with RGBA-encoded images. This might help by converting the image to RGB for duplicate evaluation.
2024-05-06 13:15:51 +02:00

274 lines
11 KiB
Python

import os
import streamlit as st
import time
import torch
import numpy as np
import faiss
from torchvision.models import resnet152, ResNet152_Weights
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from PIL import Image
from api import getImage
from utility import display_asset_column
from api import getAssetInfo
from db import load_duplicate_pairs, is_db_populated, save_duplicate_pair
from streamlit_image_comparison import image_comparison
# Set the environment variable to allow multiple OpenMP libraries
# (presumably because faiss and torch each bundle their own OpenMP runtime,
# and loading both would otherwise abort the process — confirm on target OS).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Load ResNet152 with pretrained weights; downloads the weights on first run.
model = resnet152(weights=ResNet152_Weights.DEFAULT)
model.eval()  # Set model to evaluation mode (fixed batch-norm stats, no dropout)
def convert_image_to_rgb(image):
    """Return *image* in RGB mode, converting if necessary.

    The downstream ResNet transform expects 3-channel input. The original
    code only converted 'RGBA'; palette ('P'), grayscale ('L'/'LA') and
    CMYK images would still pass through with the wrong channel count, so
    convert anything that is not already RGB.
    """
    if image.mode != 'RGB':
        return image.convert('RGB')
    return image
# Preprocessing pipeline applied to every image before feature extraction.
transform = Compose([
    convert_image_to_rgb,  # ResNet expects 3-channel input
    Resize((224, 224)), # Standard size for ImageNet-trained models
    ToTensor(),
    # ImageNet channel statistics matching the pretrained weights
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Global variables for paths
index_path = 'faiss_index.bin'  # persisted FAISS index file
metadata_path = 'metadata.npy'  # asset-id list, parallel to the index vectors
def extract_features(image):
    """Run *image* through the pretrained ResNet and return a flat 1-D feature vector."""
    batch = transform(image).unsqueeze(0)  # model expects a leading batch dimension
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        output = model(batch)
    return output.numpy().flatten()
def init_or_load_faiss_index():
    """Load the FAISS index and metadata from disk, if both have been persisted.

    Returns:
        tuple: ``(index, metadata)`` — the loaded index and asset-id list,
        or ``(None, [])`` when no persisted index exists yet.
    """
    have_persisted = os.path.exists(index_path) and os.path.exists(metadata_path)
    if not have_persisted:
        return None, []
    loaded_index = faiss.read_index(index_path)
    loaded_metadata = np.load(metadata_path, allow_pickle=True).tolist()
    return loaded_index, loaded_metadata
def save_faiss_index_and_metadata(index, metadata):
    """Persist the FAISS index and its parallel metadata list to disk."""
    # Object dtype lets arbitrary asset ids round-trip through numpy.
    np.save(metadata_path, np.array(metadata, dtype=object))
    faiss.write_index(index, index_path)
def update_faiss_index(immich_server_url, api_key, asset_id):
    """Add one asset's feature vector to the FAISS index and metadata.

    Fetches the asset's thumbnail, extracts ResNet features and appends
    them to the persisted index. Index and metadata are reloaded and
    re-saved on every call so progress survives interruption.

    Args:
        immich_server_url: Base URL of the Immich server.
        api_key: Immich API key.
        asset_id: Id of the asset to process.

    Returns:
        str: 'processed' on success, 'skipped' if the asset is already
        indexed, 'error' if the thumbnail could not be fetched.
    """
    # NOTE: the original declared `global index`, but no module-level
    # `index` exists and nothing else reads it — dropped as misleading.
    index, existing_metadata = init_or_load_faiss_index()

    # Metadata is a list of asset ids parallel to the index vectors; an id
    # already present means this asset was indexed on a previous run.
    if asset_id in existing_metadata:
        return 'skipped'

    image = getImage(asset_id, immich_server_url, "Thumbnail (fast)", api_key)
    if image is None:
        return 'error'
    features = extract_features(image)

    if index is None:
        # First vector ever: create the index with the feature dimensionality.
        dimension = features.shape[0]
        index = faiss.IndexFlatL2(dimension)

    index.add(np.array([features], dtype='float32'))
    existing_metadata.append(asset_id)
    save_faiss_index_and_metadata(index, existing_metadata)
    return 'processed'
def calculateFaissIndex(assets, immich_server_url, api_key):
    """Index every asset's image features into FAISS, with a Streamlit progress UI.

    Iterates over *assets*, calling update_faiss_index per asset, and keeps
    running counts of processed/skipped/error assets plus an estimated time
    remaining. A 'Stop Index Processing' button aborts the loop between assets.
    """
    # Initialize session state variables if they are not already set
    if 'message' not in st.session_state:
        st.session_state['message'] = ""
    if 'progress' not in st.session_state:
        st.session_state['progress'] = 0
    if 'stop_index' not in st.session_state:
        st.session_state['stop_index'] = False
    # Set up the UI components
    progress_bar = st.progress(st.session_state['progress'])
    stop_button = st.button('Stop Index Processing')
    message_placeholder = st.empty()
    # Check if stop was requested and reset it if button is pressed
    # (Streamlit reruns the script on button click, so this is evaluated
    # at the top of the rerun, before the loop starts.)
    if stop_button:
        st.session_state['stop_index'] = True
        st.session_state['calculate_faiss'] = False
    total_assets = len(assets)
    processed_assets = 0
    skipped_assets = 0
    error_assets = 0
    total_time = 0
    for i, asset in enumerate(assets):
        if st.session_state['stop_index']:
            st.session_state['message'] = "Processing stopped by user."
            message_placeholder.text(st.session_state['message'])
            break # Break the loop if stop is requested
        asset_id = asset.get('id')
        start_time = time.time()
        status = update_faiss_index(immich_server_url,api_key, asset_id)
        if status == 'processed':
            processed_assets += 1
        elif status == 'skipped':
            skipped_assets += 1
        elif status == 'error':
            error_assets += 1
        end_time = time.time()
        processing_time = end_time - start_time
        total_time += processing_time
        # Update progress and messages
        progress_percentage = (i + 1) / total_assets
        st.session_state['progress'] = progress_percentage
        progress_bar.progress(progress_percentage)
        # ETA: average time per asset so far, extrapolated to the remainder.
        estimated_time_remaining = (total_time / (i + 1)) * (total_assets - (i + 1))
        estimated_time_remaining_min = int(estimated_time_remaining / 60)
        st.session_state['message'] = f"Processing asset {i + 1}/{total_assets} - (Processed: {processed_assets}, Skipped: {skipped_assets}, Errors: {error_assets}). Estimated time remaining: {estimated_time_remaining_min} minutes."
        message_placeholder.text(st.session_state['message'])
    # Reset stop flag at the end of processing
    st.session_state['stop_index'] = False
    # Only shows "complete" when every asset was newly processed; a run that
    # skipped already-indexed assets will not reach this branch.
    if processed_assets >= total_assets:
        st.session_state['message'] = "Processing complete!"
        message_placeholder.text(st.session_state['message'])
        progress_bar.progress(1.0)
def generate_db_duplicate():
    """Scan every FAISS vector for its nearest neighbour and persist the pairs.

    For each vector i, searches the index with k=2 (the nearest hit is the
    vector itself, the second its closest neighbour) and stores the pair of
    asset ids plus their L2 distance via save_duplicate_pair. A Streamlit
    button allows stopping mid-scan.
    """
    st.write("Database initialization")
    index, metadata = init_or_load_faiss_index()
    if not index or not metadata:
        st.write("FAISS index or metadata not available.")
        return
    # Check and update the stop mechanism in session state
    if 'stop_requested' not in st.session_state:
        st.session_state['stop_requested'] = False
    # Button to request stopping
    if st.button('Stop Finding Duplicates'):
        st.session_state['stop_requested'] = True
        st.session_state['generate_db_duplicate'] = False
    num_vectors = index.ntotal
    message_placeholder = st.empty()
    progress_bar = st.progress(0)
    for i in range(num_vectors):
        # Check if stop has been requested
        if st.session_state['stop_requested']:
            message_placeholder.text("Processing was stopped by the user.")
            progress_bar.empty()
            # Optionally, reset the stop flag here if you want the process to be restartable without refreshing the page
            st.session_state['stop_requested'] = False
            return None
        progress = (i + 1) / num_vectors
        message_placeholder.text(f"Finding duplicates: processing vector {i+1} of {num_vectors}")
        progress_bar.progress(progress)
        # k=2: result [0] is vector i itself; [1] is its nearest neighbour.
        query_vector = np.array([index.reconstruct(i)])
        distances, indices = index.search(query_vector, 2)
        for j in range(1, indices.shape[1]):
            #if distances[0][j] < threshold:
            idx1, idx2 = i, indices[0][j]
            if idx1 != idx2:
                # Sort so (a, b) and (b, a) map to the same stored pair.
                sorted_pair = (min(idx1, idx2), max(idx1, idx2))
                # Check if the indices in sorted_pair are within the bounds of metadata
                if sorted_pair[0] < len(metadata) and sorted_pair[1] < len(metadata):
                    save_duplicate_pair(metadata[sorted_pair[0]], metadata[sorted_pair[1]], distances[0][j])
                else:
                    st.error(f"Metadata index out of range: {sorted_pair}")
                    # Optionally log more details or handle this case further
    message_placeholder.text(f"Finished processing {num_vectors} vectors.")
    progress_bar.empty()
def show_duplicate_photos_faiss(assets, limit, min_threshold, max_threshold,immich_server_url,api_key):
    """Render stored duplicate pairs with a side-by-side image comparison UI.

    Args:
        assets: Full asset list, used to resolve per-asset info for display.
        limit: Maximum number of duplicate pairs to render.
        min_threshold: Lower bound on the stored pair distance.
        max_threshold: Upper bound on the stored pair distance.
        immich_server_url: Base URL of the Immich server.
        api_key: Immich API key.
    """
    # First check if the database is populated
    if not is_db_populated():
        st.write("The database does not contain any duplicate entries. Please generate/update the database.")
        return # Exit the function early if the database is not populated
    # Load duplicates from database
    duplicates = load_duplicate_pairs(min_threshold, max_threshold)
    if not duplicates:
        st.write("No duplicates found.")
        return
    st.write(f"Found {len(duplicates)} duplicate pairs with FAISS code within threshold {min_threshold} < x < {max_threshold}:")
    progress_bar = st.progress(0)
    num_duplicates_to_show = min(len(duplicates), limit)
    for i, dup_pair in enumerate(duplicates[:num_duplicates_to_show]):
        try:
            # Check if stop was requested
            if st.session_state.get('stop_requested', False):
                st.write("Processing was stopped by the user.")
                st.session_state['stop_requested'] = False # Reset the flag for future operations
                st.session_state['generate_db_duplicate'] = False
                break # Exit the loop
            progress_bar.progress((i + 1) / num_duplicates_to_show)
            asset_id_1, asset_id_2 = dup_pair
            image1 = getImage(asset_id_1, immich_server_url, 'Thumbnail (fast)', api_key)
            image2 = getImage(asset_id_2, immich_server_url, 'Thumbnail (fast)', api_key)
            asset1_info = getAssetInfo(asset_id_1, assets)
            asset2_info = getAssetInfo(asset_id_2, assets)
            if image1 is not None and image2 is not None:
                # image_comparison expects arrays; convert the PIL images.
                image1 = np.array(image1)
                image2 = np.array(image2)
                # Proceed with image comparison
                image_comparison(
                    img1=image1,
                    img2=image2,
                    label1=f"Name: {asset_id_1}",
                    label2=f"Name: {asset_id_2}",
                    width=700,
                    starting_position=50,
                    show_labels=True,
                    make_responsive=True,
                    in_memory=False,
                )
                col1, col2 = st.columns(2)
                display_asset_column(col1, asset1_info, asset2_info, asset_id_1,asset_id_2, immich_server_url, api_key)
                display_asset_column(col2, asset2_info, asset1_info, asset_id_2,asset_id_1, immich_server_url, api_key)
            else:
                st.write(f"Missing information for one or both assets: {asset_id_1}, {asset_id_2}")
            st.markdown("---")
        except Exception as e:
            # Was a bare `except:` that swallowed everything (including
            # KeyboardInterrupt) behind a generic message; catch Exception
            # and surface the actual cause so failures are debuggable.
            st.write(f"Error displaying pair {dup_pair}: {e}")
    progress_bar.progress(100)