import numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import classification_report, accuracy_score
from sklearn.model_selection import train_test_split
from keras.datasets import cifar10
from datetime import datetime
import matplotlib.pyplot as plt

# Charger le jeu de données CIFAR-10
print(f"Début du chargement des données {datetime.now()}")
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

# Normaliser les données
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0

# Aplatir les images de 32x32x3 pixels en vecteurs de 3072 éléments
X_train = X_train.reshape(-1, 32*32*3)
X_test = X_test.reshape(-1, 32*32*3)

# Convertir les étiquettes en un format compatible
y_train = y_train.flatten()
y_test = y_test.flatten()
print(f"Fin du chargement des données {datetime.now()}")

# Créer un réseau de neurones classifieur 

# Extrait de la doc :
# The solver iterates until convergence (determined by `tol`) or the number of iterations (`max_iter`).
# For stochastic solvers (`sgd`, `adam`), note that this determines the number of epochs (how many times each data point will be used), not the number of gradient steps.
print(f"Début de l'entraînement {datetime.now()}")
mlp = MLPClassifier(
    hidden_layer_sizes=(256, 128, 64),
    max_iter=100, # limite d'iterations
    batch_size=128,
    learning_rate_init=0.001,
    random_state=42,
    verbose=1, # pour voir la progression de l'entrainement
)

# Entraîner le modèle
mlp.fit(X_train, y_train)
print(f"Modèle entraîné - fin de l'entrainement {datetime.now()}")

# Faire des prédictions
y_pred = mlp.predict(X_test)
print(f"Fin des prédictions {datetime.now()}")

# Évaluer le modèle
accuracy = accuracy_score(y_test, y_pred)
report = classification_report(y_test, y_pred, target_names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])

print(f"Fin de l'évaluation {datetime.now()}")
print(f"Précision du modèle : {accuracy:.2f}")
print("Rapport de classification :")
print(report)

# Visualiser quelques exemples d'images avec leurs prédictions
# fig, axes = plt.subplots(1, 5, figsize=(15, 8), subplot_kw={'xticks':[], 'yticks':[]})
# for i, ax in enumerate(axes):
#     ax.imshow(X_test[i].reshape(32, 32, 3))
#     ax.set_title(f"True: {y_test[i]}\nPred: {y_pred[i]}")
# plt.show()
