mirror of
https://github.com/vale46n1/immich_duplicate_finder.git
synced 2025-12-13 20:35:46 +01:00
Possible fix for https://github.com/vale46n1/immich_duplicate_finder/issues/32: ResNet seems to have a problem with RGBA-encoded images. This might help by converting the image to RGB for duplicate evaluation.
274 lines
11 KiB
Python
274 lines
11 KiB
Python
import os
|
|
import streamlit as st
|
|
import time
|
|
|
|
import torch
|
|
import numpy as np
|
|
import faiss
|
|
from torchvision.models import resnet152, ResNet152_Weights
|
|
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
|
|
from PIL import Image
|
|
|
|
from api import getImage
|
|
from utility import display_asset_column
|
|
from api import getAssetInfo
|
|
from db import load_duplicate_pairs, is_db_populated, save_duplicate_pair
|
|
from streamlit_image_comparison import image_comparison
|
|
|
|
# Allow multiple OpenMP libraries to be loaded side by side.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Build ResNet152 with its default pretrained weights and put it straight into
# evaluation mode (``eval()`` returns the module it was called on).
model = resnet152(weights=ResNet152_Weights.DEFAULT).eval()
|
|
|
|
def convert_image_to_rgb(image):
    """Return *image* as a three-channel RGB image.

    The preprocessing pipeline in this module normalises with three
    per-channel statistics, so every non-RGB mode (RGBA, palette ``P``,
    greyscale ``L``, ``LA``, CMYK, ...) has to be converted, not only RGBA.

    Args:
        image: A PIL-style image exposing ``mode`` and ``convert``.

    Returns:
        The very same object when it is already in RGB mode, otherwise the
        result of ``image.convert('RGB')``.
    """
    if image.mode == 'RGB':
        return image
    return image.convert('RGB')
|
|
|
|
# Per-channel statistics handed to Normalize below.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing steps, applied in this order to every image.
_preprocessing_steps = [
    convert_image_to_rgb,
    Resize((224, 224)),  # Standard size for ImageNet-trained models
    ToTensor(),
    Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = Compose(_preprocessing_steps)

# Global variables for paths: where the FAISS index and the list of asset IDs
# that accompanies it are stored on disk.
index_path = 'faiss_index.bin'
metadata_path = 'metadata.npy'
|
|
|
|
def extract_features(image):
    """Run *image* through the pretrained model and return a flat vector.

    The image is preprocessed with the module-level ``transform``, given a
    leading batch dimension, and passed to the module-level ``model`` with
    gradient tracking disabled. The model output is returned as a
    one-dimensional NumPy array.
    """
    preprocessed = transform(image)
    batch = preprocessed.unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        output = model(batch)
    flat_vector = output.numpy().flatten()
    return flat_vector
|
|
|
|
def init_or_load_faiss_index():
    """Load the persisted FAISS index and its metadata when both files exist.

    Returns:
        A ``(index, metadata)`` tuple. When either ``index_path`` or
        ``metadata_path`` is missing on disk, ``index`` is ``None`` and
        ``metadata`` is an empty list; otherwise they are the index read by
        FAISS and the metadata array converted to a Python list.
    """
    files_present = os.path.exists(index_path) and os.path.exists(metadata_path)
    if not files_present:
        return None, []

    loaded_index = faiss.read_index(index_path)
    # NOTE: allow_pickle=True lets NumPy unpickle object arrays, so this file
    # should only ever be one written by this application.
    loaded_metadata = np.load(metadata_path, allow_pickle=True).tolist()
    return loaded_index, loaded_metadata
|
|
|
|
def save_faiss_index_and_metadata(index, metadata):
    """Persist *index* to ``index_path`` and *metadata* to ``metadata_path``.

    The metadata sequence is stored as a NumPy array of dtype ``object``.
    """
    faiss.write_index(index, index_path)
    metadata_array = np.array(metadata, dtype=object)
    np.save(metadata_path, metadata_array)
|
|
|
|
def update_faiss_index(immich_server_url, api_key, asset_id):
    """Add one asset's feature vector to the on-disk FAISS index.

    The index and metadata are loaded from disk, extended with the asset's
    features and ID, and written back.

    Args:
        immich_server_url: Base URL passed through to ``getImage``.
        api_key: API key passed through to ``getImage``.
        asset_id: ID of the asset to index.

    Returns:
        ``'skipped'`` when *asset_id* is already present in the stored
        metadata, ``'error'`` when the image could not be fetched or its
        features could not be extracted, and ``'processed'`` once the vector
        and ID have been appended and saved.
    """
    # The index stays local to this call instead of being published as a
    # module-level global.
    index, existing_metadata = init_or_load_faiss_index()

    # Assets that were indexed before are not fetched again.
    if asset_id in existing_metadata:
        return 'skipped'

    image = getImage(asset_id, immich_server_url, "Thumbnail (fast)", api_key)
    if image is None:
        return 'error'

    try:
        features = extract_features(image)
    except (OSError, RuntimeError, ValueError):
        # One undecodable or unexpectedly shaped image must not abort the
        # caller's whole indexing run; report it through the status instead.
        return 'error'

    if index is None:
        # First vector ever: create the index with the feature dimension.
        index = faiss.IndexFlatL2(features.shape[0])

    index.add(np.array([features], dtype='float32'))
    existing_metadata.append(asset_id)

    save_faiss_index_and_metadata(index, existing_metadata)
    return 'processed'
|
|
|
|
def calculateFaissIndex(assets, immich_server_url, api_key):
    """Index every asset with ``update_faiss_index`` and report progress.

    A progress bar, a stop button and a status line are rendered with
    Streamlit. Progress, the status message and the stop flag are kept in
    ``st.session_state`` under ``'progress'``, ``'message'`` and
    ``'stop_index'``.

    Args:
        assets: Sequence of mappings; each asset's ID is read with
            ``asset.get('id')``.
        immich_server_url: Passed through to ``update_faiss_index``.
        api_key: Passed through to ``update_faiss_index``.
    """
    # Initialize session state variables if they are not already set.
    for key, default in (('message', ""), ('progress', 0), ('stop_index', False)):
        if key not in st.session_state:
            st.session_state[key] = default

    # Set up the UI components.
    progress_bar = st.progress(st.session_state['progress'])
    stop_button = st.button('Stop Index Processing')
    message_placeholder = st.empty()

    # Record a stop request when the button was pressed.
    if stop_button:
        st.session_state['stop_index'] = True
        st.session_state['calculate_faiss'] = False

    total_assets = len(assets)
    # One counter per status string returned by update_faiss_index.
    counts = {'processed': 0, 'skipped': 0, 'error': 0}
    total_time = 0
    stopped = False

    for i, asset in enumerate(assets):
        if st.session_state['stop_index']:
            st.session_state['message'] = "Processing stopped by user."
            message_placeholder.text(st.session_state['message'])
            stopped = True
            break

        asset_id = asset.get('id')
        start_time = time.time()

        status = update_faiss_index(immich_server_url, api_key, asset_id)
        if status in counts:
            counts[status] += 1

        total_time += time.time() - start_time

        # Update progress and messages.
        done = i + 1
        progress_percentage = done / total_assets
        st.session_state['progress'] = progress_percentage
        progress_bar.progress(progress_percentage)

        # Average time per asset so far, times the number of assets left.
        estimated_time_remaining = (total_time / done) * (total_assets - done)
        estimated_time_remaining_min = int(estimated_time_remaining / 60)

        st.session_state['message'] = (
            f"Processing asset {done}/{total_assets} - "
            f"(Processed: {counts['processed']}, Skipped: {counts['skipped']}, "
            f"Errors: {counts['error']}). "
            f"Estimated time remaining: {estimated_time_remaining_min} minutes."
        )
        message_placeholder.text(st.session_state['message'])

    # Reset stop flag at the end of processing.
    st.session_state['stop_index'] = False

    # The run is complete whenever the loop reached the end of the asset list,
    # regardless of how many assets were skipped or failed along the way.
    if not stopped:
        st.session_state['message'] = "Processing complete!"
        message_placeholder.text(st.session_state['message'])
        progress_bar.progress(1.0)
|
|
|
|
def generate_db_duplicate():
    """Store each indexed vector's nearest other vector as a duplicate pair.

    The FAISS index and metadata are loaded from disk. For every vector in
    the index the nearest neighbours are searched, and each resulting pair of
    asset IDs is handed to ``save_duplicate_pair`` together with its distance.
    Progress and a stop button are rendered with Streamlit; the stop flag
    lives in ``st.session_state['stop_requested']``.

    Returns:
        ``None`` in all cases (early when the index or metadata is missing,
        or when the user requested a stop).
    """
    st.write("Database initialization")
    index, metadata = init_or_load_faiss_index()
    if index is None or not metadata:
        st.write("FAISS index or metadata not available.")
        return

    # Check and update the stop mechanism in session state.
    if 'stop_requested' not in st.session_state:
        st.session_state['stop_requested'] = False

    # Button to request stopping.
    if st.button('Stop Finding Duplicates'):
        st.session_state['stop_requested'] = True
        st.session_state['generate_db_duplicate'] = False

    num_vectors = index.ntotal
    message_placeholder = st.empty()
    progress_bar = st.progress(0)

    # Ask for two hits per query: normally the query vector itself plus its
    # nearest other vector.
    neighbours_requested = 2

    for i in range(num_vectors):
        # Check if stop has been requested.
        if st.session_state['stop_requested']:
            message_placeholder.text("Processing was stopped by the user.")
            progress_bar.empty()
            # Reset the flag so the process can be restarted without
            # refreshing the page.
            st.session_state['stop_requested'] = False
            return None

        progress = (i + 1) / num_vectors
        message_placeholder.text(f"Finding duplicates: processing vector {i+1} of {num_vectors}")
        progress_bar.progress(progress)

        query_vector = np.array([index.reconstruct(i)])
        distances, indices = index.search(query_vector, neighbours_requested)

        # Keep only real hits that are not the query itself. FAISS pads
        # missing results with label -1, and with exact duplicates the query
        # vector is not guaranteed to be the first hit, so positions cannot
        # be relied upon.
        candidates = [
            (int(neighbour), distance)
            for neighbour, distance in zip(indices[0], distances[0])
            if neighbour >= 0 and neighbour != i
        ][:neighbours_requested - 1]

        for neighbour, distance in candidates:
            sorted_pair = (min(i, neighbour), max(i, neighbour))
            # Both entries are non-negative here, so checking the larger one
            # against the metadata length covers the whole pair.
            if sorted_pair[1] < len(metadata):
                save_duplicate_pair(metadata[sorted_pair[0]], metadata[sorted_pair[1]], distance)
            else:
                st.error(f"Metadata index out of range: {sorted_pair}")

    message_placeholder.text(f"Finished processing {num_vectors} vectors.")
    progress_bar.empty()
|
|
|
|
def show_duplicate_photos_faiss(assets, limit, min_threshold, max_threshold, immich_server_url, api_key):
    """Render the stored duplicate pairs as side-by-side comparisons.

    Pairs are loaded with ``load_duplicate_pairs(min_threshold,
    max_threshold)``. At most *limit* of them are shown, each with an image
    comparison widget and two asset columns.

    Args:
        assets: Asset collection handed to ``getAssetInfo``.
        limit: Maximum number of pairs to display.
        min_threshold: Lower bound passed to ``load_duplicate_pairs``.
        max_threshold: Upper bound passed to ``load_duplicate_pairs``.
        immich_server_url: Passed through to ``getImage`` and
            ``display_asset_column``.
        api_key: Passed through to ``getImage`` and ``display_asset_column``.
    """
    # First check if the database is populated.
    if not is_db_populated():
        st.write("The database does not contain any duplicate entries. Please generate/update the database.")
        return

    # Load duplicates from database.
    duplicates = load_duplicate_pairs(min_threshold, max_threshold)
    if not duplicates:
        st.write("No duplicates found.")
        return

    st.write(f"Found {len(duplicates)} duplicate pairs with FAISS code within threshold {min_threshold} < x < {max_threshold}:")
    progress_bar = st.progress(0)
    num_duplicates_to_show = min(len(duplicates), limit)

    for i, dup_pair in enumerate(duplicates[:num_duplicates_to_show]):
        # Check if stop was requested.
        if st.session_state.get('stop_requested', False):
            st.write("Processing was stopped by the user.")
            st.session_state['stop_requested'] = False  # Reset the flag for future operations
            st.session_state['generate_db_duplicate'] = False
            break

        try:
            progress_bar.progress((i + 1) / num_duplicates_to_show)

            asset_id_1, asset_id_2 = dup_pair

            image1 = getImage(asset_id_1, immich_server_url, 'Thumbnail (fast)', api_key)
            image2 = getImage(asset_id_2, immich_server_url, 'Thumbnail (fast)', api_key)
            asset1_info = getAssetInfo(asset_id_1, assets)
            asset2_info = getAssetInfo(asset_id_2, assets)

            if image1 is not None and image2 is not None:
                # The comparison widget is given NumPy arrays.
                image_comparison(
                    img1=np.array(image1),
                    img2=np.array(image2),
                    label1=f"Name: {asset_id_1}",
                    label2=f"Name: {asset_id_2}",
                    width=700,
                    starting_position=50,
                    show_labels=True,
                    make_responsive=True,
                    in_memory=False,
                )

                col1, col2 = st.columns(2)
                display_asset_column(col1, asset1_info, asset2_info, asset_id_1, asset_id_2, immich_server_url, api_key)
                display_asset_column(col2, asset2_info, asset1_info, asset_id_2, asset_id_1, immich_server_url, api_key)
            else:
                st.write(f"Missing information for one or both assets: {asset_id_1}, {asset_id_2}")

            st.markdown("---")
        except Exception as exc:
            # Keep going with the next pair, but only for ordinary errors: a
            # bare ``except`` would also swallow KeyboardInterrupt and other
            # BaseException-based control-flow signals. Report the real cause
            # instead of a fixed message.
            st.write(f"Could not display duplicate pair {dup_pair}: {exc}")

    progress_bar.progress(100)
|
|
|