#!/usr/bin/python3
import csv
import cv2
from math import sqrt
import numpy as np
import os
from random import randint, sample
import shutil



def stat_dataset(nom_fichier:str):
    """
    Fonction:
        Etablit les statistiques de notre base de données, selon un fichier csv fourni en paramètre.
    Paramètres:
        - nom_fichier: chemin relatif/absolu du fichier csv.
    Renvois:
        - Dictionnaire:
            - Clé:
                - Type: int.
                - Description: classe, catégorie.
            - Valeur:
                - Type: int.
                - Description: nombre d'éléments.
    """
    stats = {}
    # Lecture du fichier csv.
    with open(nom_fichier, mode='r') as file:
        reader = csv.reader(file)
        for row in reader:
            # Les labels traités sont uniquement des chiffres.
            # On identifie donc les labels en tant qu'entiers pour simplifier les traitements.
            try:
                label = int(row[1])
                if label not in stats.keys():
                    stats[label] = 1
                else:
                    stats[label] += 1
            # Si le fichier csv contient une première ligne décrivant les champs du fichiers csv, une erreur sera indiquée.
            except ValueError as error:
                print(f"Unprocessed data at row: {row}\nDue to {error}")
    return dict(stats)



def affichage_exemple(nom_fichier:str, nom_repertoire:str, nbr:int):
    """
    Fonction:
        Affiche un échantillon d'exemples (images) aléatoires d'une base de données (répertoire d'images).
        Les exemples sont sélectionnés aléatoirement parmi la base de données fournie.
    Paramètres:
        - nom_fichier: chemin relatif/absolu du fichier csv.
        - nom_repertoire: chemin relatif/absolu de la base de données.
            --> Selon le fichier csv fourni et la manière dont cette fonction a été programmée en conséquence, ce paramètre est en réalité inutile.
        - nbr: nombre aléatoire d'exemples à sélectionner dans la base de données.
    Renvois:
        - Affichage des exemples sélectionnés dans une même fenêtre.
        - Dictionnaire:
            - Clé:
                - Type: string.
                - Description: nom de l'image.
            - Valeur:
                - Type: list.
                - Description:
                    - 0: emplacement de l’image (string).
                    - 1: label de l'image (string).
    """
    # Lecture du fichier csv.
    with open(nom_fichier, mode='r') as file:
        reader = csv.reader(file)
        data = list(reader)
    # Sélection aléatoire d'exemples, puis traitement des données sélectionnées.
    random_samples = sample(data, nbr)
    returned_data = {}
    displayed_samples = {}
    for row in random_samples:
        image_path = row[0]
        image_name = image_path.split("/")[-1].split(".")[0]
        label = int(row[1])
        # Chargement et redimensionnement en x4 de l'image (carré de 28x28 pixels passant à 112x112).
        image_view = cv2.imread(image_path)
        if image_view is not None:
            image_view = cv2.resize(image_view, (112, 112))
            returned_data[image_name] = [image_path, label]
            displayed_samples[image_name] = [image_path, label, image_view]
        else:
            print("Unable to load image:", image_path)
    
    # Définition des paramètres d'affichage.
    # Nombre de colonnes: racine carrée du carré parfait supérieur ou égal au nombre d'exemples à afficher.
    lower_sqrt = int(sqrt(len(displayed_samples)))
    cols = lower_sqrt +1 if lower_sqrt != sqrt(len(displayed_samples)) else lower_sqrt
    # Espacement horizontal/vertical entre les exemples affichés.
    sample_spacing = 10
    # Espacement vertical entre un exemple affiché et son label.
    label_spacing = 20
    # Taille d'un exemple (carré).
    sample_size = 112
    # Création d'une matrice d'affichage.
    display_window = np.zeros(
        ((sample_size +sample_spacing +label_spacing) *((len(displayed_samples) +cols -1) //cols) + sample_spacing,
         (sample_size +sample_spacing) *cols, 3),
        dtype=np.uint8)
    # Parcours et traitement des exemples à afficher.
    for index, (image_name, data) in enumerate(displayed_samples.items()):
        # Position de l'exemple dans l'affichage.
        x_pos = (index %cols) *(sample_size +sample_spacing)
        y_pos = (index //cols) *(sample_size +sample_spacing +label_spacing)
        # Placement de l'exemple dans la fenêtre.
        display_window[y_pos:y_pos +sample_size, x_pos:x_pos +sample_size] = data[2]
        # Ajout d'un contour rouge pour identifier clairement l'exemple affiché.
        cv2.rectangle(display_window, (x_pos, y_pos), (x_pos +sample_size, y_pos +sample_size), (0, 0, 255), 2)
        # Affichage du label associé à l'exemple affiché.
        text = str(data[1])
        # Position du texte, centré sous l'exemple.
        org = (x_pos + sample_size//2, y_pos + sample_size + label_spacing)
        font = cv2.FONT_HERSHEY_SIMPLEX
        fontScale = 0.5
        color = (255, 255, 255)
        thickness = 1
        cv2.putText(display_window, text, org, font, fontScale, color, thickness, cv2.LINE_AA)

    # Afficher l'image composite (matrice d'affichage).
    cv2.imshow('Exemples', display_window)
    # Attendre qu'une touche soit pressée pour fermer l'affichage.
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return returned_data



def re_organisation(chemin_fichier:str, chemin_repertoire:str):
    """
    Fonction:
        Réorganise la base de données décrite par un fichier csv selon un répertoire de destination et des sous-répertoires dédiés aux labels identifiés.
        Les répertoires manquants seront créés lors de l'exécution.
    Paramètres:
        - chemin_fichier: chemin relatif/absolu du fichier csv.
        - chemin_repertoire: chemin relatif/absolu du répertoire de destination.
    Renvois: None
    """
    # Création du répertoire de destination si celui-ci n'existe pas.
    make_dir(chemin_repertoire)
    # Lecture du fichier csv.
    with open(chemin_fichier, mode='r') as file:
        reader = csv.reader(file)
        data = list(reader)
    
    # Extraction des données du fichier csv.
    samples = {}
    for row in data:
        # Les labels traités sont uniquement des chiffres.
        # On identifie donc les labels en tant qu'entiers pour simplifier les traitements.
        try:
            image_path = row[0]
            image_name = image_path.split("/")[-1].split(".")[0]
            label = int(row[1])
            samples[image_name] = [image_path, label]
        # Si le fichier csv contient une première ligne décrivant les champs du fichiers csv, une erreur sera indiquée.
        except ValueError as error:
            print(f"Unprocessed data at row: {row}\nDue to {error}")
    
    # Réorganisation du système de fichiers selon les labels trouvés dans les données extraites.
    # Utilisation d'un ensemble pour éviter les doublons.
    identified_labels = set()
    for sample in samples.keys():
        label = str(samples[sample][1])
        # Pour chaque label trouvé, on crée un sous-répertoire dans le répertoire de destination, où l'on déplacera les données associées au label identifié.
        if label not in identified_labels:
            identified_labels.add(label)
            make_dir(os.path.join(chemin_repertoire, label))
        # Déplacement des données dans les répertoires correspondants.
        filepath = samples[sample][0]
        move_file(filepath, os.path.join(chemin_repertoire, label))
    return None



def make_dir(chemin_repertoire:str):
    """
    Fonction:
        Crée un répertoire si celui-ci n'existe pas.
    Paramètres:
        - chemin_repertoire: chemin relatif/absolu du répertoire.
    Renvois: None
    """
    if not os.path.exists(chemin_repertoire):
        os.makedirs(chemin_repertoire)
    return None



def move_file(chemin_fichier:str, chemin_repertoire:str):
    """
    Fonction:
        Déplace un fichier dans un répertoire de destination.
    Paramètres:
        - chemin_fichier: chemin relatif/absolu du fichier.
        - chemin_repertoire: chemin relatif/absolu du répertoire de destination.
    Renvois: None
    """
    if os.path.exists(chemin_fichier):
        if os.path.exists(chemin_repertoire):
            try:
                shutil.move(chemin_fichier, chemin_repertoire)
            except:
                print("Specified directory is not a directory or can't be accessed.")
        else:
            print(f"Directory {chemin_repertoire} doesn't exist.")
    else:
        print(f"File {chemin_fichier} doesn't exist.")
    return None



if __name__ == '__main__':
    test = ['test_data.csv', 'dataset/test/']
    train = ['train_data.csv', 'dataset/train/']

    #stats_test = stat_dataset(test[0])
    #stats_train = stat_dataset(train[0])
    #print("Stats test:", stats_test, sep="\n")
    #print("Stats apprentissage:", stats_train, sep="\n")

    #exemples = affichage_exemple(train[0], train[1], randint(2, 25))
    #print(exemples)
    
    #re_organisation(test[0], test[1])
    #re_organisation(train[0], train[1])

