From 79bddd7f27370758b7aa86ddffc60658e2db79e4 Mon Sep 17 00:00:00 2001
From: "lucien.noel" <lucien.noel@etu.hesge.ch>
Date: Mon, 15 Jul 2024 21:06:59 +0200
Subject: [PATCH] utilise le venv pour tester le programme

---
 main.py             | 30 +++++++++++++++---------------
 parameter_finder.py | 14 +++++++++++---
 2 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/main.py b/main.py
index 4309779..893bafe 100644
--- a/main.py
+++ b/main.py
@@ -83,18 +83,18 @@ if __name__ == '__main__':
     ###############################################################################
     # Génère une mélodie
     ###############################################################################
-    if corpus.parser.min_bpm > corpus.parser.max_bpm:
-        min_temp = corpus.parser.min_bpm
-        corpus.parser.min_bpm = corpus.parser.max_bpm
-        corpus.parser.max_bpm = min_temp
-    bpm = np.random.randint(corpus.parser.min_bpm, corpus.parser.max_bpm + 1)
-    mg.generate(
-        model=model,
-        ntokens=ntokens,
-        corpus=corpus,
-        nb_words=nb_words,
-        temperature=temperature,
-        device=device,
-        data_path=data_path,
-        bpm=bpm
-    )
+    # if corpus.parser.min_bpm > corpus.parser.max_bpm:
+    #     min_temp = corpus.parser.min_bpm
+    #     corpus.parser.min_bpm = corpus.parser.max_bpm
+    #     corpus.parser.max_bpm = min_temp
+    # bpm = np.random.randint(corpus.parser.min_bpm, corpus.parser.max_bpm + 1)
+    # mg.generate(
+    #     model=model,
+    #     ntokens=ntokens,
+    #     corpus=corpus,
+    #     nb_words=nb_words,
+    #     temperature=temperature,
+    #     device=device,
+    #     data_path=data_path,
+    #     bpm=bpm
+    # )
diff --git a/parameter_finder.py b/parameter_finder.py
index d98fc51..8ab9c48 100644
--- a/parameter_finder.py
+++ b/parameter_finder.py
@@ -3,6 +3,7 @@ import subprocess
 import itertools
 import time
 import re
+import os
 
 lr = ['0.1', '1', '100']
 epochs = ['20', '50', '100']
@@ -15,6 +16,15 @@ nhid = ['100', '200', '500']
 nlayers = ['2', '6', '10']
 
 if __name__ == '__main__':
+    venv_path = 'venv'
+
+    # Path to the virtual environment's bin or Scripts directory
+    if os.name == 'nt':  # For Windows
+        venv_python = os.path.join(venv_path, 'Scripts', 'python')
+    else:  # For Unix or MacOS
+        venv_python = os.path.join(venv_path, 'bin', 'python')
+
+
     best_ppl = None
     best_config = ''
     nb_tests = 0
@@ -24,7 +34,7 @@ if __name__ == '__main__':
         nb_tests += 1
         print("Start config :", res_name)
         start_time = time.time()
-        res = subprocess.run(['python', 'main.py',
+        res = subprocess.run([venv_python, 'main.py',
                               '--lr', lr_i,
                               '--epochs', epoch_i,
                               '--batch_size', batch_size_i,
@@ -53,8 +63,6 @@ if __name__ == '__main__':
             print(res_name, test_ppl, "time :", end_time - start_time)
         else:
             print("ERROR:", res_stdout, res_stderr)
-        if nb_tests >= 15:
-            break
 
     print("nb tests :", nb_tests)
     print("best config :", best_config)
-- 
GitLab