diff --git a/parameter_finder.py b/parameter_finder.py index 8ab9c48cebfd5593368af54f207a5ac69a99ce1b..05953c2dcade734defa5eef21bd685185b525ea1 100644 --- a/parameter_finder.py +++ b/parameter_finder.py @@ -12,9 +12,21 @@ sequence_length = ['8', '32', '128'] dimension_model = ['64', '256', '512'] nhead = ['2', '4', '8'] dropout = ['0.0', '0.3', '0.6'] -nhid = ['100', '200', '500'] +nhid = ['50', '100', '500'] nlayers = ['2', '6', '10'] +params = { + '--lr': ['0.1', '1', '100'], + '--epochs': ['20', '50', '100'], + '--batch_size': ['4', '16', '64'], + '--sequence_length': ['8', '32', '128'], + '--dimension_model': ['64', '256', '512'], + '--nhead': ['2', '4', '8'], + '--dropout': ['0.0', '0.3', '0.6'], + '--nhid': ['50', '100', '500'], + '--nlayers': ['2', '6', '10'], +} + if __name__ == '__main__': venv_path = 'venv' @@ -24,46 +36,17 @@ if __name__ == '__main__': else: # For Unix or MacOS venv_python = os.path.join(venv_path, 'bin', 'python') - - best_ppl = None - best_config = '' - nb_tests = 0 - for lr_i, epoch_i, batch_size_i, sequence_length_i, dimension_model_i, nhead_i, dropout_i, nhid_i, nlayers_i in itertools.product( - lr, epochs, batch_size, sequence_length, dimension_model, nhead, dropout, nhid, nlayers): - res_name = f"lr{str.replace(lr_i, '.', '')}_epoch{epoch_i}_batch{batch_size_i}_seq{sequence_length_i}_dim{dimension_model_i}_nhead{nhead_i}_drop{str.replace(dropout_i, '.', '')}_nhid{nhid_i}_nlay{nlayers_i}" - nb_tests += 1 - print("Start config :", res_name) - start_time = time.time() - res = subprocess.run([venv_python, 'main.py', - '--lr', lr_i, - '--epochs', epoch_i, - '--batch_size', batch_size_i, - '--sequence_length', sequence_length_i, - '--dimension_model', dimension_model_i, - '--nhead', nhead_i, - '--dropout', dropout_i, - '--nhid', nhid_i, - '--nlayers', nlayers_i - ], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE - ) - end_time = time.time() - res_stdout = res.stdout.decode() - res_stderr = res.stderr.decode() - match = re.search(r'test ppl\s+([0-9.]+)', res_stdout) - if match: - - test_ppl = float(match.group(1)) - - if best_ppl is None or test_ppl < best_ppl: - best_ppl = test_ppl - best_config = res_name - - print(res_name, test_ppl, "time :", end_time - start_time) - else: - print("ERROR:", res_stdout, res_stderr) - - print("nb tests :", nb_tests) - print("best config :", best_config) - print("best ppl :", best_ppl) + for param, param_vals in params.items(): + for param_val in param_vals: + res = subprocess.run([venv_python, 'main.py', param, param_val], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + res_stdout = res.stdout.decode() + res_stderr = res.stderr.decode() + match = re.search(r'test ppl\s+([0-9.]+)', res_stdout) + if match: + test_ppl = float(match.group(1)) + print(param, param_val, test_ppl) + else: + print("ERROR:", res_stdout, res_stderr)