Skip to content
Snippets Groups Projects
Commit 79bddd7f authored by lucien.noel's avatar lucien.noel
Browse files

utilise le venv pour tester le programme

parent 172c3de3
Branches
No related tags found
No related merge requests found
......@@ -83,18 +83,18 @@ if __name__ == '__main__':
###############################################################################
# Génère une mélodie
###############################################################################
if corpus.parser.min_bpm > corpus.parser.max_bpm:
min_temp = corpus.parser.min_bpm
corpus.parser.min_bpm = corpus.parser.max_bpm
corpus.parser.max_bpm = min_temp
bpm = np.random.randint(corpus.parser.min_bpm, corpus.parser.max_bpm + 1)
mg.generate(
model=model,
ntokens=ntokens,
corpus=corpus,
nb_words=nb_words,
temperature=temperature,
device=device,
data_path=data_path,
bpm=bpm
)
# if corpus.parser.min_bpm > corpus.parser.max_bpm:
# min_temp = corpus.parser.min_bpm
# corpus.parser.min_bpm = corpus.parser.max_bpm
# corpus.parser.max_bpm = min_temp
# bpm = np.random.randint(corpus.parser.min_bpm, corpus.parser.max_bpm + 1)
# mg.generate(
# model=model,
# ntokens=ntokens,
# corpus=corpus,
# nb_words=nb_words,
# temperature=temperature,
# device=device,
# data_path=data_path,
# bpm=bpm
# )
......@@ -3,6 +3,7 @@ import subprocess
import itertools
import time
import re
import os
lr = ['0.1', '1', '100']
epochs = ['20', '50', '100']
......@@ -15,6 +16,15 @@ nhid = ['100', '200', '500']
nlayers = ['2', '6', '10']
if __name__ == '__main__':
venv_path = 'venv'
# Path to the virtual environment's bin or Scripts directory
if os.name == 'nt': # For Windows
venv_python = os.path.join(venv_path, 'Scripts', 'python')
else: # For Unix or MacOS
venv_python = os.path.join(venv_path, 'bin', 'python')
best_ppl = None
best_config = ''
nb_tests = 0
......@@ -24,7 +34,7 @@ if __name__ == '__main__':
nb_tests += 1
print("Start config :", res_name)
start_time = time.time()
res = subprocess.run(['python', 'main.py',
res = subprocess.run([venv_python, 'main.py',
'--lr', lr_i,
'--epochs', epoch_i,
'--batch_size', batch_size_i,
......@@ -53,8 +63,6 @@ if __name__ == '__main__':
print(res_name, test_ppl, "time :", end_time - start_time)
else:
print("ERROR:", res_stdout, res_stderr)
if nb_tests >= 15:
break
print("nb tests :", nb_tests)
print("best config :", best_config)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment