diff --git a/main.py b/main.py index 4309779e648637e77e56235e23bcd8fac9382b49..893bafe3de9aea628219a5c313c0c71ee7c340ad 100644 --- a/main.py +++ b/main.py @@ -83,18 +83,18 @@ if __name__ == '__main__': ############################################################################### # Génère une mélodie ############################################################################### - if corpus.parser.min_bpm > corpus.parser.max_bpm: - min_temp = corpus.parser.min_bpm - corpus.parser.min_bpm = corpus.parser.max_bpm - corpus.parser.max_bpm = min_temp - bpm = np.random.randint(corpus.parser.min_bpm, corpus.parser.max_bpm + 1) - mg.generate( - model=model, - ntokens=ntokens, - corpus=corpus, - nb_words=nb_words, - temperature=temperature, - device=device, - data_path=data_path, - bpm=bpm - ) + # if corpus.parser.min_bpm > corpus.parser.max_bpm: + # min_temp = corpus.parser.min_bpm + # corpus.parser.min_bpm = corpus.parser.max_bpm + # corpus.parser.max_bpm = min_temp + # bpm = np.random.randint(corpus.parser.min_bpm, corpus.parser.max_bpm + 1) + # mg.generate( + # model=model, + # ntokens=ntokens, + # corpus=corpus, + # nb_words=nb_words, + # temperature=temperature, + # device=device, + # data_path=data_path, + # bpm=bpm + # ) diff --git a/parameter_finder.py b/parameter_finder.py index d98fc516f0785d6d9b1471b52a49324c29cee7a3..8ab9c48cebfd5593368af54f207a5ac69a99ce1b 100644 --- a/parameter_finder.py +++ b/parameter_finder.py @@ -3,6 +3,7 @@ import subprocess import itertools import time import re +import os lr = ['0.1', '1', '100'] epochs = ['20', '50', '100'] @@ -15,6 +16,15 @@ nhid = ['100', '200', '500'] nlayers = ['2', '6', '10'] if __name__ == '__main__': + venv_path = 'venv' + + # Path to the virtual environment's bin or Scripts directory + if os.name == 'nt': # For Windows + venv_python = os.path.join(venv_path, 'Scripts', 'python') + else: # For Unix or MacOS + venv_python = os.path.join(venv_path, 'bin', 'python') + + best_ppl = None best_config = '' nb_tests = 0 @@ -24,7 +34,7 @@ if __name__ == '__main__': nb_tests += 1 print("Start config :", res_name) start_time = time.time() - res = subprocess.run(['python', 'main.py', + res = subprocess.run([venv_python, 'main.py', '--lr', lr_i, '--epochs', epoch_i, '--batch_size', batch_size_i, @@ -53,8 +63,6 @@ if __name__ == '__main__': print(res_name, test_ppl, "time :", end_time - start_time) else: print("ERROR:", res_stdout, res_stderr) - if nb_tests >= 15: - break print("nb tests :", nb_tests) print("best config :", best_config)