From 172c3de31d27e261d4f6fb14eb646a742f3c6303 Mon Sep 17 00:00:00 2001
From: "lucien.noel" <noellucien2001@gmail.com>
Date: Mon, 15 Jul 2024 20:57:54 +0200
Subject: [PATCH] =?UTF-8?q?script=20pour=20trouver=20les=20meilleurs=20par?=
 =?UTF-8?q?am=C3=A8tres=20du=20code?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py                                  |  30 +++++++----
 parameter_finder.py                      |  61 +++++++++++++++++++++++
 test.py                                  |   5 +-
 training_data/classical_music/output.mid | Bin 4120 -> 2427 bytes
 4 files changed, 85 insertions(+), 11 deletions(-)
 create mode 100644 parameter_finder.py

diff --git a/main.py b/main.py
index 2939f46..4309779 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,7 @@
 import os
 import torch
 import torch.nn as nn
+import argparse
 import numpy as np
 
 import data
@@ -9,21 +10,32 @@ import musicgenerator as mg
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--lr', type=float, default=20, help='initial learning rate')
+    parser.add_argument('--epochs', type=int, default=100, help='upper epoch limit')
+    parser.add_argument('--batch_size', type=int, default=16, help='batch size')
+    parser.add_argument('--sequence_length', type=int, default=32, help='sequence length')
+    parser.add_argument('--dimension_model', type=int, default=128, help='size of word embeddings')
+    parser.add_argument('--nhead', type=int, default=4,  help='the number of heads in the encoder/decoder of the transformer model')
+    parser.add_argument('--dropout', type=float, default=0.2, help='dropout applied to layers (0 = no dropout)')
+    parser.add_argument('--nhid', type=int, default=200, help='number of hidden units per layer')
+    parser.add_argument('--nlayers', type=int, default=2, help='number of layers')
+    args = parser.parse_args()
     ###############################################################################
     # Paramètres
     ###############################################################################
     data_path = os.path.normpath('./training_data/classical_music/')
-    learning_rate = 20
-    batch_size = 16
+    learning_rate = args.lr
+    batch_size = args.batch_size
     split_train_test_valid = (0.8, 0.1, 0.1)
     model_path = None  # os.path.join(data_path, 'model.pt')
-    sequence_length = 32
-    dim_model = 128
-    num_head = 8
-    num_layers = 2
-    num_hid = 200
-    dropout = 0.2
-    epochs = 100
+    sequence_length = args.sequence_length
+    dim_model = args.dimension_model
+    num_head = args.nhead
+    num_layers = args.nlayers
+    num_hid = args.nhid
+    dropout = args.dropout
+    epochs = args.epochs
     nb_log_epoch = 5
     nb_words = 100
     temperature = 1.0
diff --git a/parameter_finder.py b/parameter_finder.py
new file mode 100644
index 0000000..d98fc51
--- /dev/null
+++ b/parameter_finder.py
@@ -0,0 +1,61 @@
+from matplotlib import pyplot as plt
+import subprocess
+import itertools
+import time
+import re
+
+lr = ['0.1', '1', '100']
+epochs = ['20', '50', '100']
+batch_size = ['4', '16', '64']
+sequence_length = ['8', '32', '128']
+dimension_model = ['64', '256', '512']
+nhead = ['2', '4', '8']
+dropout = ['0.0', '0.3', '0.6']
+nhid = ['100', '200', '500']
+nlayers = ['2', '6', '10']
+
+if __name__ == '__main__':
+    best_ppl = None
+    best_config = ''
+    nb_tests = 0
+    for lr_i, epoch_i, batch_size_i, sequence_length_i, dimension_model_i, nhead_i, dropout_i, nhid_i, nlayers_i in itertools.product(
+            lr, epochs, batch_size, sequence_length, dimension_model, nhead, dropout, nhid, nlayers):
+        res_name = f"lr{str.replace(lr_i, '.', '')}_epoch{epoch_i}_batch{batch_size_i}_seq{sequence_length_i}_dim{dimension_model_i}_nhead{nhead_i}_drop{str.replace(dropout_i, '.', '')}_nhid{nhid_i}_nlay{nlayers_i}"
+        nb_tests += 1
+        print("Start config :", res_name)
+        start_time = time.time()
+        res = subprocess.run(['python', 'main.py',
+                              '--lr', lr_i,
+                              '--epochs', epoch_i,
+                              '--batch_size', batch_size_i,
+                              '--sequence_length', sequence_length_i,
+                              '--dimension_model', dimension_model_i,
+                              '--nhead', nhead_i,
+                              '--dropout', dropout_i,
+                              '--nhid', nhid_i,
+                              '--nlayers', nlayers_i
+                              ],
+                             stdout=subprocess.PIPE,
+                             stderr=subprocess.PIPE
+                             )
+        end_time = time.time()
+        res_stdout = res.stdout.decode()
+        res_stderr = res.stderr.decode()
+        match = re.search(r'test ppl\s+([0-9.]+)', res_stdout)
+        if match:
+
+            test_ppl = float(match.group(1))
+
+            if best_ppl is None or test_ppl < best_ppl:
+                best_ppl = test_ppl
+                best_config = res_name
+
+            print(res_name, test_ppl, "time :", end_time - start_time)
+        else:
+            print("ERROR:", res_stdout, res_stderr)
+        if nb_tests >= 15:
+            break
+
+    print("nb tests :", nb_tests)
+    print("best config :", best_config)
+    print("best ppl :", best_ppl)
diff --git a/test.py b/test.py
index 5f6a8e0..36647c2 100644
--- a/test.py
+++ b/test.py
@@ -1,4 +1,5 @@
 import argparse
+import math
 import random
 
 if __name__ == '__main__':
@@ -9,10 +10,10 @@ if __name__ == '__main__':
     parser.add_argument('--batch_size', type=int, default=16, help='batch size')
     parser.add_argument('--sequence_length', type=int, default=32, help='sequence length')
     parser.add_argument('--dimension_model', type=int, default=128, help='size of word embeddings')
-
     parser.add_argument('--nhead', type=int, default=4,  help='the number of heads in the encoder/decoder of the transformer model')
     parser.add_argument('--dropout', type=float, default=0.2, help='dropout applied to layers (0 = no dropout)')
     parser.add_argument('--nhid', type=int, default=200, help='number of hidden units per layer')
     parser.add_argument('--nlayers', type=int, default=2, help='number of layers')
 
-
+    test_loss = random.uniform(0.0, 15.0)
+    print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(test_loss, math.exp(test_loss)))
diff --git a/training_data/classical_music/output.mid b/training_data/classical_music/output.mid
index d809b8e8481dec1ff304d4fe97dea41c32ae0c64..dce21e73e0425bb7ebbb5322f415d2394683aaf7 100644
GIT binary patch
literal 2427
zcmeYb$w*;fU|?flWMEQH@C_--W?*0tVfY`&%(5?$;eP}R3zG!LxrG1v3=la^Plo@@
z3=9t#92h2;M=?zBh+>%F9mO!gKWcJBgE<33g9igcgEs?1gFgeq1ji_b37#McM+Sxl
zPX>kweo+h)LZTQZgo6bA7#JEt7#JGD85kz0L@`V-0ST%wFf^DjK+N<534qM<V_=wI
z9K|rfB8p*xYZSu-caVfJ14Dxa14Dx=14Dy51H%Meun`OsY`}`nB{b+ljB2ofD4Jjy
z1(CFiVwm6%#W2AuieW-nRGUYGB?ALkksSj=g98IYgBJrsLl^_Y1hXiH2`*sk7$*2c
zF-!;mS!>3?(BJ~GvcZRep&@{QVS-T<#C;H}+dLYK7#JEr?t@tU$YO#OSo^tz1}h{j
z&gkI>(&mg7hQ1&d`Z6#~Py*Sm1hvgHiUI5&50D^8n<)bW*dHDY3==>e^ae?QJmk%=
zn`444)Fw}`GkOIYY@xPzLY)Ir-~{z9D3ZWl1u1fZdKVN~Ag^jfK|&f7c>dr3oE*`h
z!N33xZBW4ZLjw;M!k~Z)j$)V)3UUxEbU^_Z%)rnP%D?~)8jy>fK?*>@19G)91H%N(
zsO1I?not!Wmj}UAfLtBKz%W5IieZ8+NKlo50pulc8URIxBS;Vw|Dd38WMG(}1rpGL
zCL>T}fIJKee{eW~q6%a)$itxU_XgXJh<tEpXhK869U5ey#0?4skOzIE7$yWpF-(XA
zMHVQrgF*u2NnZwrhCl{}hDZj637|LwSq;(@1PeM)yn(C+X$XP_ok0}C1YeMY4Hy_2
zd>JN3fD@P<NCe~`JCKMDG&^`i^$IlTK(hiU@<EB#6{HlDP+b`qCg?>mOaO<R9s@%I
zDAd4N)gNRCC<9wV6FA6K{vg9Znb;bdz(KA8Wn_@ELCFJ@`P`yrDS(qPC~1H)pBpr%
zgMu>z7671N41p$Ac#a3T9+u;!p()q~WT!MV_1ZvFFUT%%?gym>kUgLh02F?p;wAuG
zh)j-X0EHc>xCww1B;X7I%7S1I!ZJ0;`Jl1`9?_u41-U2;6i6V~K(hMehz2u;HV<%e
z1o^}bBn^rgkWWBf1f@|>qK3N*l$b&8asZ`M2L?!nv;k>@+6^xMK;a6u3YIKD)`7wi
zl%hdt2b4L&q8KJbfOLYw6O_6^X$O=+!WbADA{Zb=Eyz{gAiF>@1acL~C!pvCdkqxb
zAOTP)*+b(LECC87duV)u6c|DS9+U$>;S6!h<cJ1CXe5Ag0VvF&ZUGmlAdiE~8c>M_
z@;E4|8H1wE7+O$*WkHSr$-okpGt9Z51O>{|pfVU{sTat@prC?T=EcA;0pxU$N5OFj
zayH1Lpg06M+Y^-ZK+Xi!B_LzmK*<fH9TY+8;8J+GL4!J^^aWQJ;F17TO@K-PYmjrS
zp~)4TsX>tfD%L>Og3=b)O(1JQ=?df?ZD>9N#Xcy<gX#m1sNEb5+R!`)N}(WAKy?Dh
z>!ADuj&G2cL2(UAeC8mFL8Xj21E}_b)IFd=!w)0@Dp5c=4OCcw5){ZIp!^Ne399ly
zY1$s75acmXUI&>3@)IaB>=_`n894kAF#{`CCr5xQB2YrMgBE{oumS;Oj2pC2FaZ~(
zpz3wG0jL576{n!8bpog`0%d=2K!OS+kWrw(1Vtq{lYpWS6#k$H0VO6-IEH|n3yK&}
zA_9das4M`bICogt21;chkAtc<Q0f9jdLT#_$Pb`M4}``OxG)2SFsLvCMF+TO1O-1x
zD=2ItK?w;I7NAmA8^qCu1~Vvs`GXP-$R<#pf~6#ocR`61+&TfJC=&(-aPtL}vh=~Z
zkzqnCNT)s|2R6hqFiZf&Ehr&?@-(=t1jR8Z0f1r;WHiVV;EE390Z`=z3M7yd!Lb7h
MD3CKf7{CoG0M%fLmjD0&

literal 4120
zcmeYb$w*;fU|?flWMEQH@C_--W?*0tVfY`&%<A-q;eP}R3zG!LxrG1v3=lc~XAJ+D
z85kZgI512wiDH;w9>p-hE{b7-f7Eh=1``H`26G0620I3Z27d;I32spg6Fj0ACIm$>
zObChU6=-l{U}*4QU}y+pU}y+oV3=SX#W2A)YH~z_H3LI~F9Xadh6#aDZ5}XV8Uh&@
zCg?>mOfZdNm|y|5&o_!;LRi!+g$6wah6Ymxh6W3$y}k?#4Pgun6V##@CImz=Ob7<q
ztH!|45Wv9D5X`_Z!7vKqFn^G}h71f~M}b@m7638KK{DnHAiW@W&r$&C1H0D+?2*Y4
z4K7d*K)gOV0_N5U_E8KIyg^}M&%n^&%>Z#U*kq6+L8kjhF-(XC+2GH>&=AeQFu?)D
zae#&f$Uqm6eo%;j!qJ6+VS;)T!vrIcDs={i1|tSgpiYhe1>^+JD254MQ4AA&K(d|;
z3=Lik3=KXE3=_1WaRZjnhQ<p>0u+Fb;CPuF0S+`rNSsV?00-FQhz18pU`=oXIl~QZ
z*5rr=s5#(hb4HFjXK3`9L1V}@ieZ8~I3y-VG?+nS$CZJh!5tbZmQf57Y`}pDipDmN
z21^Ep1{+AAf}_t0oTi`&!3vU|zzIPmieZ8Q$hRsC3=IYhh*W3`5`?8dTLy*+AjkMb
zF--6Sr2vp4d>9xS{1`w1)++!CD6k8BLB0gp3`%~W*mDJ$1ByLYXzaOw1VFI|iVBdU
z!74ynKq`VkvY>PY$_chn3==>B1G3g1l!QUK!Ipuc0Te(W>p=+s6vdzz0OgK2kZw>k
zgJJ=cJK`8XCQOb1>6qXKk_E*8LJKshL6uJcTLLl^<StMQI>WpJibZDzNGOBTAIQNV
zm3AOQ?HG1*OmG4R&g6&&CuqQcyZ{pP04V`E9VFqwz%ao$ieUoC$<Cl0XUxFR0J0xs
zI;aS6j@r!u&iWt$xRbp>W`HsPG<i>s0GAUWH+q6X2IOgw8$B5yg$O9Zy`#=0fC~^%
zgoE^gj0c4!BqK9S2n87miepe%LNjzjC<7=(KuqxiDKLO0cTik|ayvK=fO0x0{d<8_
zctN!oL9<T?xHRb%03|3;1_PxIP=*I5MKw^6s(}k9Na#<11|h=)O>i2A6>!T98Z;s4
z8(zR!fo!vaCL~Z%=miQvP+|fVgCJ)>g9oV?fSYatve*QiU?)d_oP;PURHGOs=tMD0
z02N3+px{(xU}(@`U;r0PAaBBr1;r^iE<o`K3R6%#g3DY`9D))(C@(?;dj-Jx5GDai
zQ(&#2qyy3lQfUXW4U}8$pd}VKae+!JP~rlWy1G#e6F?aOoNquSuPy^a11J-~@((Bt
zc!2x{G8dEvKuHXgZNQd+vI@vDP~3wAT|wCw6z3obP>ur?E}%r=2P?2Z1q&!a_<;+q
zrrnS-Ap#WiAbk-G%MB)&f}Cy204h)>M}P|wP@V!Mc2AHxP-+Avc2H_mih@)(p!^9c
zTZ2Fflo%MmRShVQg38z+21xZ12Qn9w5<wXkWTZ1lHOM$nZUCh(h#%TK!08I+2Y9Ig
zk_8vEAQ@Oz4+D7_WCo}r1J!z<Xa;9WP|5+-J)jT+Nr2)Fl&ruZ2a*KE9VltRLJ%A~
zpiltC4#=1gko!PJfa+g(QUlkwpn4YM1#lJtSr1B|pdu0ENVqMaq7md+m@S}G0xn`f
zu?fm@ph_KV1SnB}i~yA$HXv(3C5H_I!vs)X2bt*wN(7+14l)m<5~8(N0Nfq`34j6+
z>_{U>^$sr7L9qt%3aEwwXLwNI4vIaHw?H)v$Q7V!HV~BAK-DbBXP{&Q4kl1?0R<1(
zTswxz5#S;l6x-h5wo2n%2T<t?ZWnk%8Y|%R3`(|;`nFd9oSs2R7h30n%QkSm017uy
zDkVA<fzmlBDnS9_4~lqDY=NQ@6eyrr1KA6XMv#3V0Z^cTvZ4nl?STRYlpR4S0@O|b
z1vbc);D7<O7eIjyawRBGKyd|5N1(U@=>Qq!4vHF(G42eYc3hhWxb0>D&BgX0Sx|`z
z%E9)~mY5YN;y^7gP|*fb3CfS)LItD}lovq-3n(^0aRinCB?eI3fFwYT6HtQ=<W{gd
zK@AI#aUi#X+zBc|K$7uM3=>j7kq9b6K$7tc3=Js^3=<%&0BBVSZ3%!P4c_hnmwzB*
zLB$@ZaRn9tF+fEVC`iHi9^@%du!8bEG(9j(kd0!P0CIsP$m_BU3=JULEg3*f)#V1@
zW-2K3L1785?LeUq3Qth(1w|xC0+jl~LE1qP36cb*!f@oq9;jjhWj1h00IHlo2?x}U
z0A)B(LIh`2P-X+AK~QD|IUE$u;HEan;h=B^r3Y}{vH>MzP+<@D7s%HjKWT&9pbc$8
zf|`CHD;+`AEGTKf+RTp7>Ja8FS5T0HViM*x)C9fUpaH2k1m_-b8UiI)P#S^<HOxG)
zS72s=yaEquNJ%g`q5)b8AgWeS_<*to$P7>-0(%P-PN3`oG7FTLKwbpJ7N{Nu1rMlX
z0OvwbtbyucP%wc?29O6psSXq`fzU7kC3A310BR_LJpd|4Kpp@Y15TG9t)TP>Dmy@#
z3ZxxeuYk%BP^JP!YXow%f-(cBs02k6INgJS5mZ!yA`6uEK@k8-exRBKRD*?qTmy=G
zP%;G7ETCEpRPMvW2vo+qf-)SaYzKJ+Bmk~6Kmwpt2g)jzAQhl20!pY-QIHN$63B2V
z28ISuHJk)($bk}x8>r|2sRkt!kP+at4$=w=Z%`5hSpxDl*b|`e1|>z1g&?nkJOWY#
zDuqF9vk;JRAO)Zj7*rdAA{mqyKuHc%41jzc3TtG6@&qX9fr<o>&qJYIHc(*%iVK*A
z0FdiJ1rsP<V2T5vjVX}(z%2}rX&`rjS{ES4gUfi3gF$6HC}Du21SA2j$3ck#6eS=9
zpgJ7n3vh-4xgIGqA+^UqsSuR7K&dbo6w06!2ufI>6bQ<(pilrOaggspJ_m&`$P*yL
zL2(R9Hc_x{A1K5@o&i}1ifd33ih_0p^`ju&7Lap6n%zOx=`%2Z`%oY!gEWCsC`f@b
zsImYFf?68z)B>^>Tw8$b0C@=Bn+5p;oQ^>52c;uWz<>&?a9DN+6;$C2;NCs}JF;I8

-- 
GitLab