From 4409bf57fce0c88fd43f31b023b8ff0bc3873cfc Mon Sep 17 00:00:00 2001 From: "lucien.noel" <lucien.noel@etu.hesge.ch> Date: Wed, 3 Jul 2024 18:35:11 +0200 Subject: [PATCH] =?UTF-8?q?ajout=20d'un=20=C3=A9tape=20de=20test=20lors=20?= =?UTF-8?q?de=20l'entrainement?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 6 ++++-- musicgenerator.py | 8 +++++++- training_data/classical_music/output.mid | Bin 4233 -> 2185 bytes 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index a7dbe88..68808b5 100644 --- a/main.py +++ b/main.py @@ -16,11 +16,11 @@ if __name__ == '__main__': learning_rate = 20 batch_size = 16 split_train_test_valid = (0.8, 0.1, 0.1) - model_path = os.path.join(data_path, 'model.pt') + model_path = None # os.path.join(data_path, 'model.pt') sequence_length = 32 epochs = 100 nb_log_epoch = 5 - nb_words = 200 + nb_words = 100 temperature = 1.0 print("device :", device) @@ -30,6 +30,7 @@ if __name__ == '__main__': ############################################################################### corpus = data.Corpus(data_path, model_path is not None, split_train_test_valid) train_data = data.batchify(corpus.train, batch_size, device) + test_data = data.batchify(corpus.test, batch_size, device) val_data = data.batchify(corpus.valid, batch_size, device) ############################################################################### @@ -53,6 +54,7 @@ if __name__ == '__main__': criterion=criterion, ntokens=ntokens, train_data=train_data, + test_data=test_data, val_data=val_data, sequence_length=sequence_length, lr=learning_rate, diff --git a/musicgenerator.py b/musicgenerator.py index 9d8cde0..6086129 100644 --- a/musicgenerator.py +++ b/musicgenerator.py @@ -94,6 +94,7 @@ def train(model: MusicGenerator, criterion: nn.NLLLoss, ntokens: int, train_data: torch.Tensor, + test_data: torch.Tensor, val_data: torch.Tensor, sequence_length: int, lr: float, @@ -123,6 +124,12 @@ def train(model: MusicGenerator, else: # Anneal the learning rate if no improvement has been seen in the validation dataset. lr /= 4.0 + test_loss = evaluate(model, criterion, test_data, ntokens, sequence_length) + print('=' * 89) + print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( + test_loss, math.exp(test_loss))) + print('=' * 89) + except KeyboardInterrupt: print('-' * 89) @@ -152,7 +159,6 @@ def __train(model: MusicGenerator, loss = criterion(output, targets) loss.backward() - # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. clip = 0.25 torch.nn.utils.clip_grad_norm_(model.parameters(), clip) for p in model.parameters(): diff --git a/training_data/classical_music/output.mid b/training_data/classical_music/output.mid index 3f6efc376cf5c74c47cba84b48c919481dbd2cb3..f426b0b79e84656bd9c37d625b0e1a54512bd457 100644 GIT binary patch literal 2185 zcmeYb$w*;fU|?flWME=p@C_--W?*0tVfY`&%wx*H@IQiug-L>=gW<nERE{H@;XgA2 z!vh8fh6&zL3=@2#niv|q85kOT85ky5MKMfpjAEGJ9K|rfE2>eu!HR*Q!I6QX!I^=f z!Ha=mf^8JT1UHZlTLy*(HwK0YHc<={9H2TpKoT|#3=IxY9Ucq}6I7xYCKyIBOfZjv z7!NYR0i;lcfuX^WfuX^ifdOm{$V3N*&<Pq~g-r|%8W4#ImSBO<21_Kq4afwLfDHq~ z1hXiJJM5wuCb)tf2y%TBLxULuLj%Y?b_@&+t`G-<eF5^KZxq7>|EN%~?|c~;8vGfW z7$+D-F-&j*8EwSC(BQ-XaW%x)CWuR+Moq8)JBncfND^e6TU2O+1;l|3AT^*s1BDOB z@u2YW1cxZZ6QJPnga)V`$TmBM7KRCyAeJQq!vyOnh6(miuXsmwFf>>*Ff`ahz2MCN zQVH^B69YIzTp1W9D8b@U2^x=}_zMA<2a2~4sF_|M0gzdshy=L{6cs*E3=@Ju5}?2Y zMS%|kLqjkF!vx(Z25@kgL@`Y811ZpDU}yjZiwOfmgC7GZ;WTQ46AdT{fGqR}B^Xd} zgDmxjx)ZDwWHv|(D6m|jS{T3~<N~z>lrX>v0%QRwF*q<ZK>{A^E|3VwJ+SzK#1+E? zg(!vzs!<FRT)=*3m=F*Z+MvL|(4fk|(BJ|IhK2wJh6%D!knjPeLnmk?$TBcAfPxB? z1f9SU!8idV2l9m%I9)V>qaT*+8$iBMjDiFTDCEGgrpUn10CF?PcXkX66F?zl78Tk6 z3LrC3jDn2w02|i=$w40AbOuVCptRQn&Qf5r!BGQ>7W1f3P=)}9JSen4u5tyX0Z_<; z(iO<LpePRnMM)sY7LdD?qZq*X1{7%^_kp4Wl<HigLK{HNRAzu=Czx|Vu>?wmt_+M5 z3_vCsFhG+O$ZSw90F?wF2Y`Yj2&4k!08nWFa)cKHLqiY)B>erN7$yXP;uI9{eozTe zh6kAdE-FBo9HbMJ9YDbW_ADqUKpxgZ$wy5L4SJ}V2pq>CMWFZ(h+>!!8pSXn667F| zeIUi41QEc%&=AVN&=AQ0Nwn4=1t1q&GcZgr02dEU3=Ia5;$VV(R46DvfwK+BXe*HO zLFoZxoE0?Tf?WxUT97NvK`P7{AgRp@!~rESkPImFHK3ua!N4%V6l@v81aSPALM&_u zWC#P7fj%H-fr7&anp{9R&=VvGN;9C`=LwDnP|4f`DY9Vs9Tb~zCxN04WH~5;z(EX( z6j1PjYzL`yi(;4nHW%aqkZLyuh6YcDHb_zI2(l663`Yir3C2(dg9<3$sL%#usGC6r z6sSr7WlxYj?$Epg%8wxH+`;(<Qc{8Ppa&#S!FkVv0g|6Ug|Qzfw4kXLTyB6et`4-G z0T~LhxP<|ndv&1o4ajtu3-mzF)q{pDD0PF%TX3}n@*OCOK!q*Ltzd7190>|aQ0#(Z z0u-}>;FyIJn-jpLBdBBqB~?()1^E+{WI<5^3PMok1&1vtpg_?BG82@HLD2*XQcx)e z3UiQRc(MjbK$3P7q#y?c0oVngf)`W=fm{I#0#I~;ixN=51@bz`L!hDr<Oo-gw?US{ zvLwjspo|T2SrkYTl)ypW2jymvYoZt+H8<GVAOSaq38CNu(iK!RgOaQ(sB8sQFd!R1 nc^jN0L2d)t1j^U2q=^WRCI)bkhU9;6Eeo#gK|b<e05>lHsv;v$ literal 4233 zcmeYb$w*;fU|?flWME=p@C_--W?*0tVfY`&%znt4;eP}R3zGy#2g84TsGLAJ!+&N5 zh6fA|3=>?V7$*2dH8C`}GB7mwF)&Q<0C7AR7$#UnF-)+FVwm6-#W2AaBxuFJ&|t^F z(BQ_v(BR9!Fu^&hg`vTjfnkDa6vG6|D3CxCLxU*;LxUv)M8YwOVM1V3Q)q)D14Ba~ z1H%NxD255TQ4AAoK;|hjFf`~gFf`aOFibFrVwhkL5;S07Xs~Bsm;lo18O1Om2qXy7 z?#aN=5X1m+l0y{31W%9z$T<!S3=N(P3=>qN7$)dLU1A@_Fu@6=K$U@^K_BWOdj^IE zCx#Y=37Sz16SSj38#Eah8nhW0C-_7$Oz@8iZSY}WXz*uXn4k{U4Gju)ur7uMP>_H; z3UL_Jqfm!Ua0l7t&cHCiJc?n0bri(;{vbhf28IS}1_p2lfXoK@-VEk^knhc)&UXc= zc4c6g;0<DUGcZiBjbfPK0}5kX1_qGZCK$uQ*ccqXp%aWivPRGt@B%Au)CR|Z7gP(# zI&W|kw?G2h8yd|Z2~eCxMKMf>2iXX63n<>A7#JGj89?C+3KUQTw1C4G6f)ol0EM$5 zNCzlv4WWJxi3)81c{K#86XKXAh)$SeK&dG>ieW+wNH-|zf*BYZVi+LNW(N`kMU@=` zG`)e+7%ag!fMj6_#(@En-oQZyax*B{K%RkwSQ8`x!a@x$2onHnb7km)Bm_`w*+(^M zgHr=2ra-9=6wxjqD?zE*g#i+$U^jqNfb0jSOHg_Q1*BmV!vqtMa!?=|GB7llFhC6V zfMfs$u;CuiOrR14O<F1pkYuGD1xh%f4cZJ0pg^2p6a`HMMhpxMu#^DGs37A&Aq&p1 zAn$-OEy!3<2!k>&$TQ&d19B!vC&*=><OT~JP;|#eg*Jd32}*gO&;uC+iuG6qMsWHC z<u2zah6!FEAAuqjl*61E7#h5w*$-SqfZ`97CqRw`rFw`CP?~9CXaG4Jq!X$Oly)ZQ zL_w0YLsX*{yvP7K7?dnrq8J-=7#JErZgF5}tOqAzkmEoe0a*o#1W-_eJOj?zpr8i% zz=@%aVS;QFs1#~qXpm(9mp~I-!0rbnB!&rLAb+?(+z%?a8p5E77F5E4N(zwE13`*F zP6m~7pppXQbWo9H1Iu(a;7kW8fWScmauz5^AZawT0g^l!CxCQ-g9xMp6hx{ZvsK}# z2vp#L%SA}Z0190{XwvYC>VO2ZOccWejVOi*;6y0Hz|f$<z|a5+15nxI3r<xak1<RL z2bm2jqkJJofIQa_4lbRcF$nSq$eG~e3rf16#2o-GT0_CX8~`m@LE@lv5&|k!KnWid z{-6vE@{l_$(SkhX4o$f7Q4AB5z=bI&MS~J8DE^^-lV@OPPy&~*;8YGu$e?u45Xdlr zVS-*112~z3qQDvC1w95xssRN*s7^5iIRaF=fK-B9<p@&+a+M=P=mbzo03{zt@z%u9 z07?y@<N_<=Kng%Q{Xl^Vk_71l<q}XiY612S$jOb`4WRPW0^%o-GbfmW>@sBl)l;A% zxCvZGfs0_RD2UrZg^EK|XoD65Lj%a^kaP=5gwQwxg&!!sK-mN0$qrNxf+RrB2iKyY z`~q@1$mbx{pi~J?<)GjIrA8-)1rXPO%Fhsx^FZ+cDnCOQnm~0OBsD~WWSkfnz^NdT zfdO1)fNTNf=?IV{$g3cmK&dVQssNOIK{kRDEl2?<AAoEHC0<afa0BTC`N$2Lb3k4P zDRN+7oB+}Z%5EN@7z617Wi^mac%=x+;7;I52%0oNIS1rQSh5EdfZ*C6R04v^Ur>$& zTMA0NAWKz1ai{|FDkykBO%gAVb)et@HAg_{8<bGYK@J7QhdH!T1SMTijSbF6ARmB& z4ODBx;s%r!KnlE~LcygL$on99kcprS2yuH8Bm+X-Js~uzjiDhF+DZYJw;(N`@)l%N zSQMne11e@f$qW?s;IIHCF;KXJ8VaBggXXDFP{wY6<S9l(`zs9OL68}s^2HM5F;D?x z3H2B#7(i|Vmlz<Afr0|$K9Ivf5em-kkT_zP07|wXr`Upw0i{`xQ*0R+Ca6IhOrUTA z*T8Df#u6x;>=^nWDF~E{yg=<ZQ0RbikryZ?Kz;>z!#|2)LIB8MP#}OJ-=Be@A%KBl z0;~{=21&vSv1kT{3Chr>T6h%0gk+GUGPIc%&cM)+49$0-$_`{SIN5@n1FGmiMuU>A z9w<HNfl3lkP6VX~50Lvn*$|W-K*<jj2tFXgL4n`{ZPtMjD=2Y!fb@fkO;CabB{Gl> zP?&;?8Bka{GC)cwP`wLo{(|x-sJ;a?OTf_qiVCm=Aj3hn!5e&_h91Z!P+bky3$g`N zM}zc&YB5l7gA_P{;uutuf&2$j<ir4JAUlFuBp?ZpGeBMh1vWTGfV>C_Xi%=Ohzf16 zU;s7ZAq{wN)dNxqiWX4#g@9XI3=<NfdLU&ONGB+YKw%jIX?-;$FhJuk02HgRQqvRU zG+3z#YF~p20#FEn8{V)K2a*6~6Oi{EKw%21r$7M?s%YSW18QLVK^xe943HWU60V?P z3tGHF!xmI@O#o#XkRL!13+e@cO#)>dklR5K4C)Plf*sUvFoC9FNOlTs0Cgcupvf4L zs~EtQ1t=AR+kP-#ftrn=QUg>5fIMIY%GRLL1C-1`p0HvFoe&Te+7Q6NH~~~fft>*= zqd-=J+8m&s1h_8)YIlIT6QI5fOd=B0&;d2}K#uf^Vgxr%K#C$k4IEJ859Cr%eg%0K zT=;^#2uq2uGy*PPKq(TIMnL5Z$hjb0ptKBbjDVaA(hW+>AR|F`f?^EhX|T~CyFjr9 z%Dx~gVDSvr4GR&NPOzUqIzavb*#(MtP=W<jwxCQI6a{U1fb0c14wQI76)z}*1~D*z zn;;Oshc-aG$_VaJ!fQ29N(Z$cK;a6`BcS#IC~QG_1mt>qPyz(G0aPJ?qRAU10E!+^ z00e__b}&N^!vs*d1_}yyQ0#%qHBc~sidaw$3-UhLN>J_vc^zaeND!n0k}Mmwz{w31 zl91%a2=2dvOaXNW{Xmfg;($UHn)JYJSy1T<E-gUqR#3?c%Bm(&p$#SsjNoz=Vt)%m z1E@@e*x#r%!8$6m!J2^)-1-5<0jPF{RF2SgGbk=VH8iw>1b2<WX#f-!Ag_Zux!_C) z>MemXASj+e5d<#4L2(R<5KuP?)cOZEY(dTaKu{9`ls3VU3ret{$OVN3D5ZgFbWr$! l5*aA(f+8E7Za@(WN-v;z0=X69?FpcsGo<|r@j7^T0RU-aE9n3L -- GitLab