From 4dbb3133871d758c13ede7f81b6ba15fc4b80a07 Mon Sep 17 00:00:00 2001
From: "raphael.stouder" <raphael.stouder@etu.hesge.ch>
Date: Tue, 21 Jan 2025 22:31:20 +0100
Subject: [PATCH] [feature] add GUI version

---
 tpGradientGUI.py | 541 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 541 insertions(+)
 create mode 100644 tpGradientGUI.py

diff --git a/tpGradientGUI.py b/tpGradientGUI.py
new file mode 100644
index 0000000..d4a7cc1
--- /dev/null
+++ b/tpGradientGUI.py
@@ -0,0 +1,541 @@
+import itertools
+import sys
+import threading
+import time
+import numpy as np
+import matplotlib.pyplot as plt
+import pandas as pd
+from matplotlib.animation import FuncAnimation
+from matplotlib.colors import LinearSegmentedColormap
+import matplotlib.widgets as widgets
+from mpl_toolkits.mplot3d import Axes3D
+
+momentum = 0.8
+learn_rate = 0.00001
+maxIter = 10000
+tolerance = 1e-05
+initialPoint = [0, 0]
+
+learn_rate_textbox = None
+momentum_textbox = None
+maxIter_textbox = None
+tolerance_textbox = None
+initialPoint_x_textbox = None
+initialPoint_y_textbox = None
+
+# Créer une carte de couleur personnalisée pour les segments de l'animation
+colors = [(0, "green"), (0.5, "orange"), (1, "red")]
+custom_cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)
+
+# Fonction pour calculer la dérivée numérique
+def numericalDerivative(f,x,y,eps):
+    gx = (f(x+eps, y) - f(x, y)) / (1 * eps)
+    gy = (f(x, y+eps) - f(x, y)) / (1 * eps)
+    return np.array([gx,gy])
+
+# Algorithme de descente de gradient simple
+def gradient_descent(cost, start, learn_rate, tolerance, n_iter):
+    df = pd.DataFrame(columns=["Iteration", "X", "Y", "Cost", "Gradient"])
+    position = np.array(start, dtype=float)
+    for i in range(n_iter):
+        # Calculer le gradient à la position actuelle
+        grad = numericalDerivative(cost, position[0], position[1], 1e-6)
+        # Mettre à jour la position en fonction du gradient et du taux d'apprentissage
+        new_pos = position - learn_rate * grad
+        # Enregistrer les informations de l'itération actuelle
+        df.loc[i] = [i, new_pos[0], new_pos[1], cost(new_pos[0], new_pos[1]), grad]
+        position = new_pos
+        # Arrêter si la norme du gradient est inférieure à la tolérance
+        if np.linalg.norm(grad) <= tolerance:
+            break
+    return df
+
+# Algorithme de descente de gradient avec momentum
+def gradient_descent2DMomentum(cost,start, learn_rate,momentum, tolerance,n_iter):
+    df = pd.DataFrame(columns=["Iteration", "X","Y","Cost","Stepx","Stepy",'NormGrad'])
+    vector = start
+    grad = numericalDerivative(cost, vector[0], vector[1], 0.00005)
+    normGrad = np.linalg.norm(grad) 
+    step=np.array([0,0])
+    
+    for k in range(n_iter):
+        # Enregistrer les informations de l'itération actuelle
+        df.loc[k]=[k,vector[0],vector[1],cost(vector[0],vector[1]),step[0],step[1],normGrad]
+        # Calculer le gradient à la position actuelle
+        grad = numericalDerivative(cost, vector[0], vector[1], 0.000005)
+        # Mettre à jour le pas en fonction du gradient, du taux d'apprentissage et du momentum
+        step = learn_rate*grad + momentum*step
+        # Mettre à jour la position
+        vector = vector - step
+        normGrad = np.linalg.norm(grad) 
+        # Arrêter si la norme du gradient est inférieure à la tolérance
+        if np.all(np.abs(normGrad) <= tolerance):
+            df.loc[k+1]=[k+1,vector[0],vector[1],cost(vector[0],vector[1]),step[0],step[1],normGrad]
+            break
+    return df
+
+# Algorithme de descente de gradient avec Nesterov
+def gradient_descent2DNesterov(cost, start, learn_rate, momentum, tolerance, n_iter):
+    df = pd.DataFrame(columns=["Iteration", "X", "Y", "Cost", "Stepx", "Stepy", "NormGrad"])
+    position = np.array(start, dtype=float)
+    velocity = np.zeros(2)
+    for i in range(n_iter):
+        # Calculer la position en avant
+        ahead = position - momentum * velocity
+        # Calculer le gradient à la position en avant
+        grad = numericalDerivative(cost, ahead[0], ahead[1], 1e-6)
+        # Mettre à jour la vélocité et la position
+        new_vel = momentum * velocity + learn_rate * grad
+        new_pos = position - new_vel
+        norm_g = np.linalg.norm(grad)
+        # Enregistrer les informations de l'itération actuelle
+        df.loc[i] = [i, new_pos[0], new_pos[1], cost(new_pos[0], new_pos[1]), new_vel[0], new_vel[1], norm_g]
+        position, velocity = new_pos, new_vel
+        # Arrêter si la norme du gradient est inférieure à la tolérance
+        if norm_g <= tolerance:
+            break
+    return df
+
+# Algorithme de descente de gradient avec Adam
+def gradient_descent2DAdam(cost, start, learn_rate, beta1, beta2, epsilon, tolerance, n_iter):
+    df = pd.DataFrame(columns=["Iteration", "X", "Y", "Cost", "Stepx", "Stepy", "NormGrad"])
+    position = np.array(start, dtype=float)
+    m = np.zeros(2)
+    v = np.zeros(2)
+    for i in range(n_iter):
+        # Calculer le gradient à la position actuelle
+        grad = numericalDerivative(cost, position[0], position[1], 1e-6)
+        # Mettre à jour les moments de premier et second ordre
+        m = beta1 * m + (1 - beta1) * grad
+        v = beta2 * v + (1 - beta2) * grad ** 2
+        # Calculer les moments corrigés
+        m_hat = m / (1 - beta1 ** (i + 1))
+        v_hat = v / (1 - beta2 ** (i + 1))
+        # Mettre à jour la position
+        new_pos = position - learn_rate * m_hat / (np.sqrt(v_hat) + epsilon)
+        norm_g = np.linalg.norm(grad)
+        # Enregistrer les informations de l'itération actuelle
+        df.loc[i] = [i, new_pos[0], new_pos[1], cost(new_pos[0], new_pos[1]), m_hat[0], m_hat[1], norm_g]
+        position = new_pos
+        # Arrêter si la norme du gradient est inférieure à la tolérance
+        if norm_g <= tolerance:
+            break
+    return df
+
+# Fonction pour choisir la fonction de coût
+def choose_function():
+    print("Choose a function:")
+    print("1) x2 + y2")
+    print("2) Beale")
+    print("3) Rosenbrock")
+    print("4) Himmelblau")
+    print("5) Ackley")
+    return input("Enter your choice (1-5): ")
+
+# Fonction pour choisir le type de descente de gradient
+def choose_gradient():
+    print("Choose gradient descent type:")
+    print("1) Simple")
+    print("2) Momentum")
+    print("3) Nesterov")
+    print("4) Adam")
+    return input("Enter your choice (1-4): ")
+
+# Fonction pour choisir le point initial
+def choose_initialPoint():
+    print("Choose initial point:")
+    print("1) Manual")
+    print("2) Random")
+    return input("Enter your choice (1-2): ")
+
+def choose_show_gradients():
+    print("Show gradients on projection ?")
+    print("1) Yes")
+    print("2) No")
+    return input("Enter your choice (1-2): ")
+
+done = False
+def terminal_loading():
+    global done
+    for c in itertools.cycle(["⢿", "⣻", "⣽", "⣾", "⣷", "⣯", "⣟", "⡿"]):
+        if done:
+            break
+        sys.stdout.write('\r' + c + ' Computing')
+        sys.stdout.flush()
+        time.sleep(0.1)
+    sys.stdout.write('\rDone!           \n\n')
+
+# Fonction pour exécuter les choix de l'utilisateur
+def compute_choices(func_choice, grad_choice, i_choice, sg_choice):
+    global done
+
+    # Choisir la fonction à optimiser
+    if func_choice == "1":
+        f = lambda x, y: x ** 2 + y ** 2
+        search_domain = [-10, 10]
+        trueMinPoints = [[0, 0]]
+    elif func_choice == "2":
+        f = lambda x, y: (1.5 - x + x*y)**2 + (2.25 - x + x*(y**2))**2 + (2.625 - x + x*(y**3))**2
+        search_domain = [-4.5, 4.5]
+        trueMinPoints = [[3, 0.5]]
+    elif func_choice == "3":
+        f = lambda x, y: (1 - x)**2 + 100 * (y - x**2)**2
+        search_domain = [-10, 10]
+        trueMinPoints = [[1, 1]]
+    elif func_choice == "4":
+        f = lambda x, y: (x**2 + y - 11)**2 + (x + y**2 - 7)**2
+        search_domain = [-5, 5]
+        trueMinPoints = [[3, 2], [-2.805118, 3.131312], [-3.779310, -3.283186], [3.584428, -1.848126]]
+    elif func_choice == "5":
+        f =  lambda x, y: -20 * np.exp(-0.2 * np.sqrt(0.5 * (x**2 + y**2))) \
+                 - np.exp(0.5 * (np.cos(2.0 * np.pi * x) + np.cos(2.0 * np.pi * y))) \
+                 + np.e + 20
+        search_domain = [-5, 5]
+        trueMinPoints = [[0, 0]]
+    
+    # Choisir le point initial
+    if i_choice == "1":
+        initialPoint = [float(input("Enter x: ")), float(input("Enter y: "))]
+    elif i_choice == "2":
+        initialPoint = np.random.uniform(search_domain[0], search_domain[1], 2)
+
+    done = False
+    t = threading.Thread(target=terminal_loading)
+    t.start()
+
+    done = True
+
+    # Choisir si on affiche les gradients sur la projection
+    if sg_choice == "1":
+        show_gradients = True
+    else:
+        show_gradients = False
+
+    # Afficher les résultats
+    path = df[["X", "Y"]].values
+    print_gradient_descent(initialPoint, df, trueMinPoints)
+    create_plots(f, initialPoint, path, trueMinPoints, show_gradients)
+
+def compute_gradient_descent(f, initialPoint, learn_rate, momentum, tolerance, n_iter, grad_choice):
+    # Exécuter la descente de gradient
+    if grad_choice == "Simple":
+        df = gradient_descent(f, initialPoint, learn_rate, tolerance, maxIter)
+    elif grad_choice == "Momentum":
+        df = gradient_descent2DMomentum(f, initialPoint, learn_rate, momentum, tolerance, maxIter)
+    elif grad_choice == "Nesterov":
+        df = gradient_descent2DNesterov(f, initialPoint, learn_rate, momentum, tolerance, maxIter)
+    elif grad_choice == "Adam":
+        df = gradient_descent2DAdam(f, initialPoint, learn_rate, 0.9, 0.999, 1e-8, tolerance, maxIter)
+
+    return df
+
+# Variables globales
+f = lambda x, y: x ** 2 + y ** 2
+trueMinPoints = [[0, 0]]
+search_domain = [-10, 10]
+grad_choice = "Simple"
+show_gradients = False
+
+ani = None  # Variable globale pour stocker l'animation
+
+def update_plot(event, choice):
+    global df, done, grad_choice, ax1, ax2, fig
+    grad_choice = choice
+    done = False
+    t = threading.Thread(target=terminal_loading)
+    t.start()
+
+    df = compute_gradient_descent(f, initialPoint, learn_rate, momentum, tolerance, maxIter, grad_choice)
+    path = df[["X", "Y"]].values
+    done = True
+
+    # Effacer les axes actuels
+    ax1.cla()
+    ax2.cla()
+
+    # Recréer les plots avec les nouvelles données
+    create_plots(f, initialPoint, path, trueMinPoints, show_gradients)
+
+    # Redessiner la figure
+    fig.canvas.draw()
+    plt.draw()  # Ajoutez cette ligne pour forcer le rafraîchissement de la figure
+
+def update_function(event, func_choice):
+    global f, trueMinPoints, search_domain
+    if func_choice == "x2 + y2":
+        f = lambda x, y: x ** 2 + y ** 2
+        search_domain = [-10, 10]
+        trueMinPoints = [[0, 0]]
+    elif func_choice == "Beale":
+        f = lambda x, y: (1.5 - x + x*y)**2 + (2.25 - x + x*(y**2))**2 + (2.625 - x + x*(y**3))**2
+        search_domain = [-4.5, 4.5]
+        trueMinPoints = [[3, 0.5]]
+    elif func_choice == "Rosenbrock":
+        f = lambda x, y: (1 - x)**2 + 100 * (y - x**2)**2
+        search_domain = [-10, 10]
+        trueMinPoints = [[1, 1]]
+    elif func_choice == "Himmelblau":
+        f = lambda x, y: (x**2 + y - 11)**2 + (x + y**2 - 7)**2
+        search_domain = [-5, 5]
+        trueMinPoints = [[3, 2], [-2.805118, 3.131312], [-3.779310, -3.283186], [3.584428, -1.848126]]
+    elif func_choice == "Ackley":
+        f =  lambda x, y: -20 * np.exp(-0.2 * np.sqrt(0.5 * (x**2 + y**2))) \
+                 - np.exp(0.5 * (np.cos(2.0 * np.pi * x) + np.cos(2.0 * np.pi * y))) \
+                 + np.e + 20
+        search_domain = [-5, 5]
+        trueMinPoints = [[0, 0]]
+
+def update_global_vars(event):
+    global learn_rate, momentum, maxIter, tolerance, initialPoint
+    global learn_rate_textbox, momentum_textbox, maxIter_textbox, tolerance_textbox, initialPoint_x_textbox, initialPoint_y_textbox
+
+    learn_rate = float(learn_rate_textbox.text)
+    momentum = float(momentum_textbox.text)
+    maxIter = int(maxIter_textbox.text)
+    tolerance = float(tolerance_textbox.text)
+    initialPoint = [float(initialPoint_x_textbox.text), float(initialPoint_y_textbox.text)]
+
+# Fonction pour imprimer les résultats de la descente de gradient
+def print_gradient_descent(initialPoint, df, trueMinPoints):
+    print("\nGradient Descent Path:")
+    #print(df.to_markdown())
+    print(df.head(100).to_markdown())
+    print("...")
+    print(df.tail(100).to_markdown())
+    print("\nInitial point: ({}, {})".format(initialPoint[0], initialPoint[1]))
+    print("Final founded point: ({}, {})".format(df["X"].iloc[-1], df["Y"].iloc[-1]))
+    print("True minimum points: {}".format(trueMinPoints))
+    print("Minimum founded value: {}".format(df["Cost"].iloc[-1]))
+
+# Fonction pour créer les graphiques
+def create_plots(function, initialPoint, path, trueMinPoints, show_gradients):
+    global ax1, ax2, fig, ani
+    # Configurer le premier sous-plot (3D)
+    scale = 10
+    X, Y = np.mgrid[-scale:scale:30j, -scale:scale:30j]
+    Z = function(X, Y)
+    ax1.plot_surface(X, Y, Z, cmap='viridis', alpha=0.8)
+    ax1.set_title('3D Plot')
+    ax1.set_xlabel('x')
+    ax1.set_ylabel('y')
+    ax1.set_zlabel('f(x, y)')
+
+    # Configurer le deuxième sous-plot (contour)
+    ax2.contour(X, Y, Z, 20, cmap='plasma')
+    ax2.set_title('Contour Plot')
+    ax2.set_xlabel('x')
+    ax2.set_ylabel('y')
+
+    # Ajouter les points initiaux et finaux aux plots
+    ax1.plot([initialPoint[0]], [initialPoint[1]], [function(initialPoint[0], initialPoint[1])], 'mp', zorder=7)
+    ax2.plot([initialPoint[0]], [initialPoint[1]], 'mp', zorder=7)
+    ax1.plot([path[-1][0]], [path[-1][1]], [function(path[-1][0], path[-1][1])], 'cd', zorder=7)
+    ax2.plot([path[-1][0]], [path[-1][1]], 'cd', zorder=7)
+
+    # Ajouter les points de minimums globaux aux plots
+    if trueMinPoints:
+        for point in trueMinPoints:
+            ax1.plot([point[0]], [point[1]], [function(point[0], point[1])], 'kX', zorder=7)
+            ax2.plot([point[0]], [point[1]], 'kX', zorder=7)
+
+    # Créer l'animation du chemin de descente de gradient
+    lines1 = []
+    lines2 = []
+
+    # Définir une distance minimale pour filtrer les points trop proches
+    min_distance = 0.01
+
+    # Filtrer les points trop proches
+    filtered_path = [path[0]]
+    for point in path[1:]:
+        if np.linalg.norm(point - filtered_path[-1]) >= min_distance:
+            filtered_path.append(point)
+    filtered_path = np.array(filtered_path)
+
+    # Rajouter le point initial au début du chemin
+    filtered_path = np.vstack([initialPoint, filtered_path])
+
+    #DEBUG
+    print("Original path length: {}".format(len(path)))
+    print("Filtered path length: {}".format(len(filtered_path)))
+
+    # Precalculate z values for each point in the path
+    zdata_path = np.array([function(point[0], point[1]) for point in filtered_path])
+
+    # Precalculate step sizes and colors for each segment
+    step_sizes = np.array([np.linalg.norm(filtered_path[i] - filtered_path[i-1]) for i in range(1, len(filtered_path))])
+    max_step_size = np.max(step_sizes)
+    colors = np.array([custom_cmap(1 - min(step_size / max_step_size, 1)) for step_size in step_sizes])
+
+    for _ in range(len(filtered_path)):
+        line1, = ax1.plot([], [], [], 'r.-', markersize=5, lw=1, zorder=6)
+        line2, = ax2.plot([], [], 'r.-', markersize=5, lw=1, zorder=6)
+        lines1.append(line1)
+        lines2.append(line2)
+
+    def init():
+        for line1, line2 in zip(lines1, lines2):
+            line1.set_data([initialPoint[0]], [initialPoint[1]])
+            line1.set_3d_properties([function(initialPoint[0], initialPoint[1])])
+            line2.set_data([initialPoint[0]], [initialPoint[1]])
+        return lines1 + lines2
+
+    def update(frame):
+        if frame > 0:
+            xdata = filtered_path[frame-1:frame+1, 0]
+            ydata = filtered_path[frame-1:frame+1, 1]
+            zdata = zdata_path[frame-1:frame+1]
+            
+            color = colors[frame-1]
+            
+            lines1[frame-1].set_data(xdata, ydata)
+            lines1[frame-1].set_3d_properties(zdata)
+            lines1[frame-1].set_color(color)
+            
+            lines2[frame-1].set_data(xdata, ydata)
+            lines2[frame-1].set_color(color)
+        
+        return lines1 + lines2
+
+    # Calculer l'intervalle pour que l'animation dure 3 secondes
+    total_frames = len(filtered_path)
+    interval = 3000 / total_frames  # 3 secondes divisées par le nombre de frames
+
+    # Ajouter les flèches de gradients au plot de contour
+    if show_gradients:
+        all_grads = []
+        for x in range(-scale, scale + 1, 1):
+            for y in range(-scale, scale + 1, 1):
+                grad_temp = numericalDerivative(function, x, y, 1e-6)
+                all_grads.append(np.linalg.norm(grad_temp))
+        max_grad = max(all_grads) if all_grads else 1
+
+        min_arrow_factor = 0.3  # Taille minimale
+        for x in range(-scale, scale + 1, 2):
+            for y in range(-scale, scale + 1, 2):
+                if x in {-scale, scale} or y in {-scale, scale}:
+                    continue
+                grad = numericalDerivative(function, x, y, 1e-6)
+                norm_grad = np.linalg.norm(grad)
+                if norm_grad != 0 and max_grad != 0:
+                    arrow_vec = grad / max_grad
+                    arrow_length = np.linalg.norm(arrow_vec)
+
+                    # appliquer la taille minimale
+                    if arrow_length < min_arrow_factor:
+                        arrow_vec *= min_arrow_factor / arrow_length
+                        arrow_length = min_arrow_factor
+
+                    head_size = 0.5 * arrow_length
+                    ax2.arrow(
+                        x, y,
+                        arrow_vec[0], arrow_vec[1],
+                        head_width=head_size,
+                        head_length=head_size,
+                        fc='k', ec='k', zorder=5
+                    )
+
+    # Arrêter l'animation actuelle si elle existe
+    # if ani:
+        # ani.event_source.stop()
+
+    
+
+    # Afficher les résultats de la descente de gradient dans le terminal
+    print_gradient_descent(initialPoint, df, trueMinPoints)
+
+    # Créer l'animation
+    ani = FuncAnimation(fig, update, frames=total_frames, init_func=init, blit=True, repeat=False, interval=interval)
+
+def toggle_show_gradients(event):
+    global show_gradients
+    show_gradients = not show_gradients
+    print(f"Show gradients: {show_gradients}")
+
+def create_empty_plots():
+    global ax1, ax2, fig
+    global learn_rate_textbox, momentum_textbox, maxIter_textbox, tolerance_textbox, initialPoint_x_textbox, initialPoint_y_textbox
+    # Créer une seule figure avec deux sous-plots
+    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
+
+    # Configurer le premier sous-plot (3D)
+    ax1 = fig.add_subplot(121, projection='3d')
+    ax1.set_title('3D Plot')
+    ax1.set_xlabel('x')
+    ax1.set_ylabel('y')
+    ax1.set_zlabel('f(x, y)')
+
+    # Configurer le deuxième sous-plot (contour)
+    ax2.set_title('Contour Plot')
+    ax2.set_xlabel('x')
+    ax2.set_ylabel('y')
+
+    # Ajouter les boutons pour les types de descente de gradient
+    ax_button1 = plt.axes([0.01, 0.8, 0.1, 0.05])
+    ax_button2 = plt.axes([0.01, 0.7, 0.1, 0.05])
+    ax_button3 = plt.axes([0.01, 0.6, 0.1, 0.05])
+    ax_button4 = plt.axes([0.01, 0.5, 0.1, 0.05])
+
+    button1 = widgets.Button(ax_button1, 'Simple')
+    button2 = widgets.Button(ax_button2, 'Momentum')
+    button3 = widgets.Button(ax_button3, 'Nesterov')
+    button4 = widgets.Button(ax_button4, 'Adam')
+
+    button1.on_clicked(lambda event: update_plot(event, 'Simple'))
+    button2.on_clicked(lambda event: update_plot(event, 'Momentum'))
+    button3.on_clicked(lambda event: update_plot(event, 'Nesterov'))
+    button4.on_clicked(lambda event: update_plot(event, 'Adam'))
+
+    # Ajouter les boutons pour les types de fonctions
+    ax_func_button1 = plt.axes([0.85, 0.8, 0.1, 0.05])
+    ax_func_button2 = plt.axes([0.85, 0.7, 0.1, 0.05])
+    ax_func_button3 = plt.axes([0.85, 0.6, 0.1, 0.05])
+    ax_func_button4 = plt.axes([0.85, 0.5, 0.1, 0.05])
+    ax_func_button5 = plt.axes([0.85, 0.4, 0.1, 0.05])
+
+    func_button1 = widgets.Button(ax_func_button1, 'x2 + y2')
+    func_button2 = widgets.Button(ax_func_button2, 'Beale')
+    func_button3 = widgets.Button(ax_func_button3, 'Rosenbrock')
+    func_button4 = widgets.Button(ax_func_button4, 'Himmelblau')
+    func_button5 = widgets.Button(ax_func_button5, 'Ackley')
+
+    func_button1.on_clicked(lambda event: update_function(event, 'x2 + y2'))
+    func_button2.on_clicked(lambda event: update_function(event, 'Beale'))
+    func_button3.on_clicked(lambda event: update_function(event, 'Rosenbrock'))
+    func_button4.on_clicked(lambda event: update_function(event, 'Himmelblau'))
+    func_button5.on_clicked(lambda event: update_function(event, 'Ackley'))
+
+    # Ajouter les champs de texte pour les variables globales
+    ax_learn_rate = plt.axes([0.85, 0.3, 0.1, 0.05])
+    ax_momentum = plt.axes([0.85, 0.25, 0.1, 0.05])
+    ax_maxIter = plt.axes([0.85, 0.2, 0.1, 0.05])
+    ax_tolerance = plt.axes([0.85, 0.15, 0.1, 0.05])
+    ax_initialPoint_x = plt.axes([0.85, 0.1, 0.05, 0.05])
+    ax_initialPoint_y = plt.axes([0.9, 0.1, 0.05, 0.05])
+
+    learn_rate_textbox = widgets.TextBox(ax_learn_rate, 'Learn Rate', initial=str(learn_rate))
+    momentum_textbox = widgets.TextBox(ax_momentum, 'Momentum', initial=str(momentum))
+    maxIter_textbox = widgets.TextBox(ax_maxIter, 'Max Iter', initial=str(maxIter))
+    tolerance_textbox = widgets.TextBox(ax_tolerance, 'Tolerance', initial=str(tolerance))
+    initialPoint_x_textbox = widgets.TextBox(ax_initialPoint_x, 'Init X', initial=str(initialPoint[0]))
+    initialPoint_y_textbox = widgets.TextBox(ax_initialPoint_y, 'Init Y', initial=str(initialPoint[1]))
+
+    # Ajouter le bouton pour mettre à jour les variables globales
+    ax_update_button = plt.axes([0.85, 0.05, 0.1, 0.05])
+    update_button = widgets.Button(ax_update_button, 'Update')
+    update_button.on_clicked(update_global_vars)
+
+    # Ajouter le bouton pour activer/désactiver l'affichage des flèches de gradient
+    ax_toggle_gradients_button = plt.axes([0.01, 0.05, 0.1, 0.05])
+    toggle_gradients_button = widgets.Button(ax_toggle_gradients_button, 'Toggle Gradients')
+    toggle_gradients_button.on_clicked(toggle_show_gradients)
+
+    plt.subplots_adjust(left=0.15, right=0.8, top=0.9, bottom=0.1)
+    plt.show()
+
+# Demander les choix de l'utilisateur et exécuter le programme
+def main():
+    create_empty_plots()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
-- 
GitLab