From 172c3de31d27e261d4f6fb14eb646a742f3c6303 Mon Sep 17 00:00:00 2001 From: "lucien.noel" <noellucien2001@gmail.com> Date: Mon, 15 Jul 2024 20:57:54 +0200 Subject: [PATCH] =?UTF-8?q?script=20pour=20trouver=20les=20meilleurs=20par?= =?UTF-8?q?am=C3=A8tres=20du=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 30 +++++++---- parameter_finder.py | 61 +++++++++++++++++++++++ test.py | 5 +- training_data/classical_music/output.mid | Bin 4120 -> 2427 bytes 4 files changed, 85 insertions(+), 11 deletions(-) create mode 100644 parameter_finder.py diff --git a/main.py b/main.py index 2939f46..4309779 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ import os import torch import torch.nn as nn +import argparse import numpy as np import data @@ -9,21 +10,32 @@ import musicgenerator as mg device = "cuda" if torch.cuda.is_available() else "cpu" if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--lr', type=float, default=20, help='initial learning rate') + parser.add_argument('--epochs', type=int, default=100, help='upper epoch limit') + parser.add_argument('--batch_size', type=int, default=16, help='batch size') + parser.add_argument('--sequence_length', type=int, default=32, help='sequence length') + parser.add_argument('--dimension_model', type=int, default=128, help='size of word embeddings') + parser.add_argument('--nhead', type=int, default=4, help='the number of heads in the encoder/decoder of the transformer model') + parser.add_argument('--dropout', type=float, default=0.2, help='dropout applied to layers (0 = no dropout)') + parser.add_argument('--nhid', type=int, default=200, help='number of hidden units per layer') + parser.add_argument('--nlayers', type=int, default=2, help='number of layers') + args = parser.parse_args() ############################################################################### # Paramètres ############################################################################### data_path = os.path.normpath('./training_data/classical_music/') - learning_rate = 20 - batch_size = 16 + learning_rate = args.lr + batch_size = args.batch_size split_train_test_valid = (0.8, 0.1, 0.1) model_path = None # os.path.join(data_path, 'model.pt') - sequence_length = 32 - dim_model = 128 - num_head = 8 - num_layers = 2 - num_hid = 200 - dropout = 0.2 - epochs = 100 + sequence_length = args.sequence_length + dim_model = args.dimension_model + num_head = args.nhead + num_layers = args.nlayers + num_hid = args.nhid + dropout = args.dropout + epochs = args.epochs nb_log_epoch = 5 nb_words = 100 temperature = 1.0 diff --git a/parameter_finder.py b/parameter_finder.py new file mode 100644 index 0000000..d98fc51 --- /dev/null +++ b/parameter_finder.py @@ -0,0 +1,61 @@ +from matplotlib import pyplot as plt +import subprocess +import itertools +import time +import re + +lr = ['0.1', '1', '100'] +epochs = ['20', '50', '100'] +batch_size = ['4', '16', '64'] +sequence_length = ['8', '32', '128'] +dimension_model = ['64', '256', '512'] +nhead = ['2', '4', '8'] +dropout = ['0.0', '0.3', '0.6'] +nhid = ['100', '200', '500'] +nlayers = ['2', '6', '10'] + +if __name__ == '__main__': + best_ppl = None + best_config = '' + nb_tests = 0 + for lr_i, epoch_i, batch_size_i, sequence_length_i, dimension_model_i, nhead_i, dropout_i, nhid_i, nlayers_i in itertools.product( + lr, epochs, batch_size, sequence_length, dimension_model, nhead, dropout, nhid, nlayers): + res_name = f"lr{str.replace(lr_i, '.', '')}_epoch{epoch_i}_batch{batch_size_i}_seq{sequence_length_i}_dim{dimension_model_i}_nhead{nhead_i}_drop{str.replace(dropout_i, '.', '')}_nhid{nhid_i}_nlay{nlayers_i}" + nb_tests += 1 + print("Start config :", res_name) + start_time = time.time() + res = subprocess.run(['python', 'main.py', + '--lr', lr_i, + '--epochs', epoch_i, + '--batch_size', batch_size_i, + '--sequence_length', sequence_length_i, + '--dimension_model', dimension_model_i, + '--nhead', nhead_i, + '--dropout', dropout_i, + '--nhid', nhid_i, + '--nlayers', nlayers_i + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + end_time = time.time() + res_stdout = res.stdout.decode() + res_stderr = res.stderr.decode() + match = re.search(r'test ppl\s+([0-9.]+)', res_stdout) + if match: + + test_ppl = float(match.group(1)) + + if best_ppl is None or test_ppl < best_ppl: + best_ppl = test_ppl + best_config = res_name + + print(res_name, test_ppl, "time :", end_time - start_time) + else: + print("ERROR:", res_stdout, res_stderr) + if nb_tests >= 15: + break + + print("nb tests :", nb_tests) + print("best config :", best_config) + print("best ppl :", best_ppl) diff --git a/test.py b/test.py index 5f6a8e0..36647c2 100644 --- a/test.py +++ b/test.py @@ -1,4 +1,5 @@ import argparse +import math import random if __name__ == '__main__': @@ -9,10 +10,10 @@ if __name__ == '__main__': parser.add_argument('--batch_size', type=int, default=16, help='batch size') parser.add_argument('--sequence_length', type=int, default=32, help='sequence length') parser.add_argument('--dimension_model', type=int, default=128, help='size of word embeddings') - parser.add_argument('--nhead', type=int, default=4, help='the number of heads in the encoder/decoder of the transformer model') parser.add_argument('--dropout', type=float, default=0.2, help='dropout applied to layers (0 = no dropout)') parser.add_argument('--nhid', type=int, default=200, help='number of hidden units per layer') parser.add_argument('--nlayers', type=int, default=2, help='number of layers') - + test_loss = random.uniform(0.0, 15.0) + print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(test_loss, math.exp(test_loss))) diff --git a/training_data/classical_music/output.mid b/training_data/classical_music/output.mid index d809b8e8481dec1ff304d4fe97dea41c32ae0c64..dce21e73e0425bb7ebbb5322f415d2394683aaf7 100644 GIT binary patch literal 2427 zcmeYb$w*;fU|?flWMEQH@C_--W?*0tVfY`&%(5?$;eP}R3zG!LxrG1v3=la^Plo@@ z3=9t#92h2;M=?zBh+>%F9mO!gKWcJBgE<33g9igcgEs?1gFgeq1ji_b37#McM+Sxl zPX>kweo+h)LZTQZgo6bA7#JEt7#JGD85kz0L@`V-0ST%wFf^DjK+N<534qM<V_=wI z9K|rfB8p*xYZSu-caVfJ14Dxa14Dx=14Dy51H%Meun`OsY`}`nB{b+ljB2ofD4Jjy z1(CFiVwm6%#W2AuieW-nRGUYGB?ALkksSj=g98IYgBJrsLl^_Y1hXiH2`*sk7$*2c zF-!;mS!>3?(BJ~GvcZRep&@{QVS-T<#C;H}+dLYK7#JEr?t@tU$YO#OSo^tz1}h{j z&gkI>(&mg7hQ1&d`Z6#~Py*Sm1hvgHiUI5&50D^8n<)bW*dHDY3==>e^ae?QJmk%= zn`444)Fw}`GkOIYY@xPzLY)Ir-~{z9D3ZWl1u1fZdKVN~Ag^jfK|&f7c>dr3oE*`h z!N33xZBW4ZLjw;M!k~Z)j$)V)3UUxEbU^_Z%)rnP%D?~)8jy>fK?*>@19G)91H%N( zsO1I?not!Wmj}UAfLtBKz%W5IieZ8+NKlo50pulc8URIxBS;Vw|Dd38WMG(}1rpGL zCL>T}fIJKee{eW~q6%a)$itxU_XgXJh<tEpXhK869U5ey#0?4skOzIE7$yWpF-(XA zMHVQrgF*u2NnZwrhCl{}hDZj637|LwSq;(@1PeM)yn(C+X$XP_ok0}C1YeMY4Hy_2 zd>JN3fD@P<NCe~`JCKMDG&^`i^$IlTK(hiU@<EB#6{HlDP+b`qCg?>mOaO<R9s@%I zDAd4N)gNRCC<9wV6FA6K{vg9Znb;bdz(KA8Wn_@ELCFJ@`P`yrDS(qPC~1H)pBpr% zgMu>z7671N41p$Ac#a3T9+u;!p()q~WT!MV_1ZvFFUT%%?gym>kUgLh02F?p;wAuG zh)j-X0EHc>xCww1B;X7I%7S1I!ZJ0;`Jl1`9?_u41-U2;6i6V~K(hMehz2u;HV<%e z1o^}bBn^rgkWWBf1f@|>qK3N*l$b&8asZ`M2L?!nv;k>@+6^xMK;a6u3YIKD)`7wi zl%hdt2b4L&q8KJbfOLYw6O_6^X$O=+!WbADA{Zb=Eyz{gAiF>@1acL~C!pvCdkqxb zAOTP)*+b(LECC87duV)u6c|DS9+U$>;S6!h<cJ1CXe5Ag0VvF&ZUGmlAdiE~8c>M_ z@;E4|8H1wE7+O$*WkHSr$-okpGt9Z51O>{|pfVU{sTat@prC?T=EcA;0pxU$N5OFj zayH1Lpg06M+Y^-ZK+Xi!B_LzmK*<fH9TY+8;8J+GL4!J^^aWQJ;F17TO@K-PYmjrS zp~)4TsX>tfD%L>Og3=b)O(1JQ=?df?ZD>9N#Xcy<gX#m1sNEb5+R!`)N}(WAKy?Dh z>!ADuj&G2cL2(UAeC8mFL8Xj21E}_b)IFd=!w)0@Dp5c=4OCcw5){ZIp!^Ne399ly zY1$s75acmXUI&>3@)IaB>=_`n894kAF#{`CCr5xQB2YrMgBE{oumS;Oj2pC2FaZ~( zpz3wG0jL576{n!8bpog`0%d=2K!OS+kWrw(1Vtq{lYpWS6#k$H0VO6-IEH|n3yK&} zA_9das4M`bICogt21;chkAtc<Q0f9jdLT#_$Pb`M4}``OxG)2SFsLvCMF+TO1O-1x zD=2ItK?w;I7NAmA8^qCu1~Vvs`GXP-$R<#pf~6#ocR`61+&TfJC=&(-aPtL}vh=~Z zkzqnCNT)s|2R6hqFiZf&Ehr&?@-(=t1jR8Z0f1r;WHiVV;EE390Z`=z3M7yd!Lb7h MD3CKf7{CoG0M%fLmjD0& literal 4120 zcmeYb$w*;fU|?flWMEQH@C_--W?*0tVfY`&%<A-q;eP}R3zG!LxrG1v3=lc~XAJ+D z85kZgI512wiDH;w9>p-hE{b7-f7Eh=1``H`26G0620I3Z27d;I32spg6Fj0ACIm$> zObChU6=-l{U}*4QU}y+pU}y+oV3=SX#W2A)YH~z_H3LI~F9Xadh6#aDZ5}XV8Uh&@ zCg?>mOfZdNm|y|5&o_!;LRi!+g$6wah6Ymxh6W3$y}k?#4Pgun6V##@CImz=Ob7<q ztH!|45Wv9D5X`_Z!7vKqFn^G}h71f~M}b@m7638KK{DnHAiW@W&r$&C1H0D+?2*Y4 z4K7d*K)gOV0_N5U_E8KIyg^}M&%n^&%>Z#U*kq6+L8kjhF-(XC+2GH>&=AeQFu?)D zae#&f$Uqm6eo%;j!qJ6+VS;)T!vrIcDs={i1|tSgpiYhe1>^+JD254MQ4AA&K(d|; z3=Lik3=KXE3=_1WaRZjnhQ<p>0u+Fb;CPuF0S+`rNSsV?00-FQhz18pU`=oXIl~QZ z*5rr=s5#(hb4HFjXK3`9L1V}@ieZ8~I3y-VG?+nS$CZJh!5tbZmQf57Y`}pDipDmN z21^Ep1{+AAf}_t0oTi`&!3vU|zzIPmieZ8Q$hRsC3=IYhh*W3`5`?8dTLy*+AjkMb zF--6Sr2vp4d>9xS{1`w1)++!CD6k8BLB0gp3`%~W*mDJ$1ByLYXzaOw1VFI|iVBdU z!74ynKq`VkvY>PY$_chn3==>B1G3g1l!QUK!Ipuc0Te(W>p=+s6vdzz0OgK2kZw>k zgJJ=cJK`8XCQOb1>6qXKk_E*8LJKshL6uJcTLLl^<StMQI>WpJibZDzNGOBTAIQNV zm3AOQ?HG1*OmG4R&g6&&CuqQcyZ{pP04V`E9VFqwz%ao$ieUoC$<Cl0XUxFR0J0xs zI;aS6j@r!u&iWt$xRbp>W`HsPG<i>s0GAUWH+q6X2IOgw8$B5yg$O9Zy`#=0fC~^% zgoE^gj0c4!BqK9S2n87miepe%LNjzjC<7=(KuqxiDKLO0cTik|ayvK=fO0x0{d<8_ zctN!oL9<T?xHRb%03|3;1_PxIP=*I5MKw^6s(}k9Na#<11|h=)O>i2A6>!T98Z;s4 z8(zR!fo!vaCL~Z%=miQvP+|fVgCJ)>g9oV?fSYatve*QiU?)d_oP;PURHGOs=tMD0 z02N3+px{(xU}(@`U;r0PAaBBr1;r^iE<o`K3R6%#g3DY`9D))(C@(?;dj-Jx5GDai zQ(qyy3lQfUXW4U}8$pd}VKae+!JP~rlWy1G#e6F?aOoNquSuPy^a11J-~@((Bt zc!2x{G8dEvKuHXgZNQd+vI@vDP~3wAT|wCw6z3obP>ur?E}%r=2P?2Z1q&!a_<;+q zrrnS-Ap#WiAbk-G%MB)&f}Cy204h)>M}P|wP@V!Mc2AHxP-+Avc2H_mih@)(p!^9c zTZ2Fflo%MmRShVQg38z+21xZ12Qn9w5<wXkWTZ1lHOM$nZUCh(h#%TK!08I+2Y9Ig zk_8vEAQ@Oz4+D7_WCo}r1J!z<Xa;9WP|5+-J)jT+Nr2)Fl&ruZ2a*KE9VltRLJ%A~ zpiltC4#=1gko!PJfa+g(QUlkwpn4YM1#lJtSr1B|pdu0ENVqMaq7md+m@S}G0xn`f zu?fm@ph_KV1SnB}i~yA$HXv(3C5H_I!vs)X2bt*wN(7+14l)m<5~8(N0Nfq`34j6+ z>_{U>^$sr7L9qt%3aEwwXLwNI4vIaHw?H)v$Q7V!HV~BAK-DbBXP{&Q4kl1?0R<1( zTswxz5#S;l6x-h5wo2n%2T<t?ZWnk%8Y|%R3`(|;`nFd9oSs2R7h30n%QkSm017uy zDkVA<fzmlBDnS9_4~lqDY=NQ@6eyrr1KA6XMv#3V0Z^cTvZ4nl?STRYlpR4S0@O|b z1vbc);D7<O7eIjyawRBGKyd|5N1(U@=>Qq!4vHF(G42eYc3hhWxb0>D&BgX0Sx|`z z%E9)~mY5YN;y^7gP|*fb3CfS)LItD}lovq-3n(^0aRinCB?eI3fFwYT6HtQ=<W{gd zK@AI#aUi#X+zBc|K$7uM3=>j7kq9b6K$7tc3=Js^3=<%&0BBVSZ3%!P4c_hnmwzB* zLB$@ZaRn9tF+fEVC`iHi9^@%du!8bEG(9j(kd0!P0CIsP$m_BU3=JULEg3*f)#V1@ zW-2K3L1785?LeUq3Qth(1w|xC0+jl~LE1qP36cb*!f@oq9;jjhWj1h00IHlo2?x}U z0A)B(LIh`2P-X+AK~QD|IUE$u;HEan;h=B^r3Y}{vH>MzP+<@D7s%HjKWT&9pbc$8 zf|`CHD;+`AEGTKf+RTp7>Ja8FS5T0HViM*x)C9fUpaH2k1m_-b8UiI)P#S^<HOxG) zS72s=yaEquNJ%g`q5)b8AgWeS_<*to$P7>-0(%P-PN3`oG7FTLKwbpJ7N{Nu1rMlX z0OvwbtbyucP%wc?29O6psSXq`fzU7kC3A310BR_LJpd|4Kpp@Y15TG9t)TP>Dmy@# z3ZxxeuYk%BP^JP!YXow%f-(cBs02k6INgJS5mZ!yA`6uEK@k8-exRBKRD*?qTmy=G zP%;G7ETCEpRPMvW2vo+qf-)SaYzKJ+Bmk~6Kmwpt2g)jzAQhl20!pY-QIHN$63B2V z28ISuHJk)($bk}x8>r|2sRkt!kP+at4$=w=Z%`5hSpxDl*b|`e1|>z1g&?nkJOWY# zDuqF9vk;JRAO)Zj7*rdAA{mqyKuHc%41jzc3TtG6@&qX9fr<o>&qJYIHc(*%iVK*A z0FdiJ1rsP<V2T5vjVX}(z%2}rX&`rjS{ES4gUfi3gF$6HC}Du21SA2j$3ck#6eS=9 zpgJ7n3vh-4xgIGqA+^UqsSuR7K&dbo6w06!2ufI>6bQ<(pilrOaggspJ_m&`$P*yL zL2(R9Hc_x{A1K5@o&i}1ifd33ih_0p^`ju&7Lap6n%zOx=`%2Z`%oY!gEWCsC`f@b zsImYFf?68z)B>^>Tw8$b0C@=Bn+5p;oQ^>52c;uWz<>&?a9DN+6;$C2;NCs}JF;I8 -- GitLab