From 4409bf57fce0c88fd43f31b023b8ff0bc3873cfc Mon Sep 17 00:00:00 2001
From: "lucien.noel" <lucien.noel@etu.hesge.ch>
Date: Wed, 3 Jul 2024 18:35:11 +0200
Subject: [PATCH] =?UTF-8?q?ajout=20d'un=20=C3=A9tape=20de=20test=20lors=20?=
 =?UTF-8?q?de=20l'entrainement?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py                                  |   6 ++++--
 musicgenerator.py                        |   8 +++++++-
 training_data/classical_music/output.mid | Bin 4233 -> 2185 bytes
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/main.py b/main.py
index a7dbe88..68808b5 100644
--- a/main.py
+++ b/main.py
@@ -16,11 +16,11 @@ if __name__ == '__main__':
     learning_rate = 20
     batch_size = 16
     split_train_test_valid = (0.8, 0.1, 0.1)
-    model_path = os.path.join(data_path, 'model.pt')
+    model_path = None  # os.path.join(data_path, 'model.pt')
     sequence_length = 32
     epochs = 100
     nb_log_epoch = 5
-    nb_words = 200
+    nb_words = 100
     temperature = 1.0
 
     print("device :", device)
@@ -30,6 +30,7 @@ if __name__ == '__main__':
     ###############################################################################
     corpus = data.Corpus(data_path, model_path is not None, split_train_test_valid)
     train_data = data.batchify(corpus.train, batch_size, device)
+    test_data = data.batchify(corpus.test, batch_size, device)
     val_data = data.batchify(corpus.valid, batch_size, device)
 
     ###############################################################################
@@ -53,6 +54,7 @@ if __name__ == '__main__':
             criterion=criterion,
             ntokens=ntokens,
             train_data=train_data,
+            test_data=test_data,
             val_data=val_data,
             sequence_length=sequence_length,
             lr=learning_rate,
diff --git a/musicgenerator.py b/musicgenerator.py
index 9d8cde0..6086129 100644
--- a/musicgenerator.py
+++ b/musicgenerator.py
@@ -94,6 +94,7 @@ def train(model: MusicGenerator,
           criterion: nn.NLLLoss,
           ntokens: int,
           train_data: torch.Tensor,
+          test_data: torch.Tensor,
           val_data: torch.Tensor,
           sequence_length: int,
           lr: float,
@@ -123,6 +124,12 @@ def train(model: MusicGenerator,
             else:
                 # Anneal the learning rate if no improvement has been seen in the validation dataset.
                 lr /= 4.0
+        test_loss = evaluate(model, criterion, test_data, ntokens, sequence_length)
+        print('=' * 89)
+        print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
+            test_loss, math.exp(test_loss)))
+        print('=' * 89)
+
 
     except KeyboardInterrupt:
         print('-' * 89)
@@ -152,7 +159,6 @@ def __train(model: MusicGenerator,
         loss = criterion(output, targets)
         loss.backward()
 
-        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
         clip = 0.25
         torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
         for p in model.parameters():
diff --git a/training_data/classical_music/output.mid b/training_data/classical_music/output.mid
index 3f6efc376cf5c74c47cba84b48c919481dbd2cb3..f426b0b79e84656bd9c37d625b0e1a54512bd457 100644
GIT binary patch
literal 2185
zcmeYb$w*;fU|?flWME=p@C_--W?*0tVfY`&%wx*H@IQiug-L>=gW<nERE{H@;XgA2
z!vh8fh6&zL3=@2#niv|q85kOT85ky5MKMfpjAEGJ9K|rfE2>eu!HR*Q!I6QX!I^=f
z!Ha=mf^8JT1UHZlTLy*(HwK0YHc<={9H2TpKoT|#3=IxY9Ucq}6I7xYCKyIBOfZjv
z7!NYR0i;lcfuX^WfuX^ifdOm{$V3N*&<Pq~g-r|%8W4#ImSBO<21_Kq4afwLfDHq~
z1hXiJJM5wuCb)tf2y%TBLxULuLj%Y?b_@&+t`G-<eF5^KZxq7>|EN%~?|c~;8vGfW
z7$+D-F-&j*8EwSC(BQ-XaW%x)CWuR+Moq8)JBncfND^e6TU2O+1;l|3AT^*s1BDOB
z@u2YW1cxZZ6QJPnga)V`$TmBM7KRCyAeJQq!vyOnh6(miuXsmwFf>>*Ff`ahz2MCN
zQVH^B69YIzTp1W9D8b@U2^x=}_zMA<2a2~4sF_|M0gzdshy=L{6cs*E3=@Ju5}?2Y
zMS%|kLqjkF!vx(Z25@kgL@`Y811ZpDU}yjZiwOfmgC7GZ;WTQ46AdT{fGqR}B^Xd}
zgDmxjx)ZDwWHv|(D6m|jS{T3~<N~z>lrX>v0%QRwF*q<ZK>{A^E|3VwJ+SzK#1+E?
zg(!vzs!<FRT)=*3m=F*Z+MvL|(4fk|(BJ|IhK2wJh6%D!knjPeLnmk?$TBcAfPxB?
z1f9SU!8idV2l9m%I9)V>qaT*+8$iBMjDiFTDCEGgrpUn10CF?PcXkX66F?zl78Tk6
z3LrC3jDn2w02|i=$w40AbOuVCptRQn&Qf5r!BGQ>7W1f3P=)}9JSen4u5tyX0Z_<;
z(iO<LpePRnMM)sY7LdD?qZq*X1{7%^_kp4Wl<HigLK{HNRAzu=Czx|Vu>?wmt_+M5
z3_vCsFhG+O$ZSw90F?wF2Y`Yj2&4k!08nWFa)cKHLqiY)B>erN7$yXP;uI9{eozTe
zh6kAdE-FBo9HbMJ9YDbW_ADqUKpxgZ$wy5L4SJ}V2pq>CMWFZ(h+>!!8pSXn667F|
zeIUi41QEc%&=AVN&=AQ0Nwn4=1t1q&GcZgr02dEU3=Ia5;$VV(R46DvfwK+BXe*HO
zLFoZxoE0?Tf?WxUT97NvK`P7{AgRp@!~rESkPImFHK3ua!N4%V6l@v81aSPALM&_u
zWC#P7fj%H-fr7&anp{9R&=VvGN;9C`=LwDnP|4f`DY9Vs9Tb~zCxN04WH~5;z(EX(
z6j1PjYzL`yi(;4nHW%aqkZLyuh6YcDHb_zI2(l663`Yir3C2(dg9<3$sL%#usGC6r
z6sSr7WlxYj?$Epg%8wxH+`;(<Qc{8Ppa&#S!FkVv0g|6Ug|Qzfw4kXLTyB6et`4-G
z0T~LhxP<|ndv&1o4ajtu3-mzF)q{pDD0PF%TX3}n@*OCOK!q*Ltzd7190>|aQ0#(Z
z0u-}>;FyIJn-jpLBdBBqB~?()1^E+{WI<5^3PMok1&1vtpg_?BG82@HLD2*XQcx)e
z3UiQRc(MjbK$3P7q#y?c0oVngf)`W=fm{I#0#I~;ixN=51@bz`L!hDr<Oo-gw?US{
zvLwjspo|T2SrkYTl)ypW2jymvYoZt+H8<GVAOSaq38CNu(iK!RgOaQ(sB8sQFd!R1
nc^jN0L2d)t1j^U2q=^WRCI)bkhU9;6Eeo#gK|b<e05>lHsv;v$

literal 4233
zcmeYb$w*;fU|?flWME=p@C_--W?*0tVfY`&%znt4;eP}R3zGy#2g84TsGLAJ!+&N5
zh6fA|3=>?V7$*2dH8C`}GB7mwF)&Q<0C7AR7$#UnF-)+FVwm6-#W2AaBxuFJ&|t^F
z(BQ_v(BR9!Fu^&hg`vTjfnkDa6vG6|D3CxCLxU*;LxUv)M8YwOVM1V3Q)q)D14Ba~
z1H%NxD255TQ4AAoK;|hjFf`~gFf`aOFibFrVwhkL5;S07Xs~Bsm;lo18O1Om2qXy7
z?#aN=5X1m+l0y{31W%9z$T<!S3=N(P3=>qN7$)dLU1A@_Fu@6=K$U@^K_BWOdj^IE
zCx#Y=37Sz16SSj38#Eah8nhW0C-_7$Oz@8iZSY}WXz*uXn4k{U4Gju)ur7uMP>_H;
z3UL_Jqfm!Ua0l7t&cHCiJc?n0bri(;{vbhf28IS}1_p2lfXoK@-VEk^knhc)&UXc=
zc4c6g;0<DUGcZiBjbfPK0}5kX1_qGZCK$uQ*ccqXp%aWivPRGt@B%Au)CR|Z7gP(#
zI&W|kw?G2h8yd|Z2~eCxMKMf>2iXX63n<>A7#JGj89?C+3KUQTw1C4G6f)ol0EM$5
zNCzlv4WWJxi3)81c{K#86XKXAh)$SeK&dG>ieW+wNH-|zf*BYZVi+LNW(N`kMU@=`
zG`)e+7%ag!fMj6_#(@En-oQZyax*B{K%RkwSQ8`x!a@x$2onHnb7km)Bm_`w*+(^M
zgHr=2ra-9=6wxjqD?zE*g#i+$U^jqNfb0jSOHg_Q1*BmV!vqtMa!?=|GB7llFhC6V
zfMfs$u;CuiOrR14O<F1pkYuGD1xh%f4cZJ0pg^2p6a`HMMhpxMu#^DGs37A&Aq&p1
zAn$-OEy!3<2!k>&$TQ&d19B!vC&*=><OT~JP;|#eg*Jd32}*gO&;uC+iuG6qMsWHC
z<u2zah6!FEAAuqjl*61E7#h5w*$-SqfZ`97CqRw`rFw`CP?~9CXaG4Jq!X$Oly)ZQ
zL_w0YLsX*{yvP7K7?dnrq8J-=7#JErZgF5}tOqAzkmEoe0a*o#1W-_eJOj?zpr8i%
zz=@%aVS;QFs1#~qXpm(9mp~I-!0rbnB!&rLAb+?(+z%?a8p5E77F5E4N(zwE13`*F
zP6m~7pppXQbWo9H1Iu(a;7kW8fWScmauz5^AZawT0g^l!CxCQ-g9xMp6hx{ZvsK}#
z2vp#L%SA}Z0190{XwvYC>VO2ZOccWejVOi*;6y0Hz|f$<z|a5+15nxI3r<xak1<RL
z2bm2jqkJJofIQa_4lbRcF$nSq$eG~e3rf16#2o-GT0_CX8~`m@LE@lv5&|k!KnWid
z{-6vE@{l_$(SkhX4o$f7Q4AB5z=bI&MS~J8DE^^-lV@OPPy&~*;8YGu$e?u45Xdlr
zVS-*112~z3qQDvC1w95xssRN*s7^5iIRaF=fK-B9<p@&+a+M=P=mbzo03{zt@z%u9
z07?y@<N_<=Kng%Q{Xl^Vk_71l<q}XiY612S$jOb`4WRPW0^%o-GbfmW>@sBl)l;A%
zxCvZGfs0_RD2UrZg^EK|XoD65Lj%a^kaP=5gwQwxg&!!sK-mN0$qrNxf+RrB2iKyY
z`~q@1$mbx{pi~J?<)GjIrA8-)1rXPO%Fhsx^FZ+cDnCOQnm~0OBsD~WWSkfnz^NdT
zfdO1)fNTNf=?IV{$g3cmK&dVQssNOIK{kRDEl2?<AAoEHC0<afa0BTC`N$2Lb3k4P
zDRN+7oB+}Z%5EN@7z617Wi^mac%=x+;7;I52%0oNIS1rQSh5EdfZ*C6R04v^Ur>$&
zTMA0NAWKz1ai{|FDkykBO%gAVb)et@HAg_{8<bGYK@J7QhdH!T1SMTijSbF6ARmB&
z4ODBx;s%r!KnlE~LcygL$on99kcprS2yuH8Bm+X-Js~uzjiDhF+DZYJw;(N`@)l%N
zSQMne11e@f$qW?s;IIHCF;KXJ8VaBggXXDFP{wY6<S9l(`zs9OL68}s^2HM5F;D?x
z3H2B#7(i|Vmlz<Afr0|$K9Ivf5em-kkT_zP07|wXr`Upw0i{`xQ*0R+Ca6IhOrUTA
z*T8Df#u6x;>=^nWDF~E{yg=<ZQ0RbikryZ?Kz;>z!#|2)LIB8MP#}OJ-=Be@A%KBl
z0;~{=21&vSv1kT{3Chr>T6h%0gk+GUGPIc%&cM)+49$0-$_`{SIN5@n1FGmiMuU>A
z9w<HNfl3lkP6VX~50Lvn*$|W-K*<jj2tFXgL4n`{ZPtMjD=2Y!fb@fkO;CabB{Gl>
zP?&;?8Bka{GC)cwP`wLo{(|x-sJ;a?OTf_qiVCm=Aj3hn!5e&_h91Z!P+bky3$g`N
zM}zc&YB5l7gA_P{;uutuf&2$j<ir4JAUlFuBp?ZpGeBMh1vWTGfV>C_Xi%=Ohzf16
zU;s7ZAq{wN)dNxqiWX4#g@9XI3=<NfdLU&ONGB+YKw%jIX?-;$FhJuk02HgRQqvRU
zG+3z#YF~p20#FEn8{V)K2a*6~6Oi{EKw%21r$7M?s%YSW18QLVK^xe943HWU60V?P
z3tGHF!xmI@O#o#XkRL!13+e@cO#)>dklR5K4C)Plf*sUvFoC9FNOlTs0Cgcupvf4L
zs~EtQ1t=AR+kP-#ftrn=QUg>5fIMIY%GRLL1C-1`p0HvFoe&Te+7Q6NH~~~fft>*=
zqd-=J+8m&s1h_8)YIlIT6QI5fOd=B0&;d2}K#uf^Vgxr%K#C$k4IEJ859Cr%eg%0K
zT=;^#2uq2uGy*PPKq(TIMnL5Z$hjb0ptKBbjDVaA(hW+>AR|F`f?^EhX|T~CyFjr9
z%Dx~gVDSvr4GR&NPOzUqIzavb*#(MtP=W<jwxCQI6a{U1fb0c14wQI76)z}*1~D*z
zn;;Oshc-aG$_VaJ!fQ29N(Z$cK;a6`BcS#IC~QG_1mt>qPyz(G0aPJ?qRAU10E!+^
z00e__b}&N^!vs*d1_}yyQ0#%qHBc~sidaw$3-UhLN>J_vc^zaeND!n0k}Mmwz{w31
zl91%a2=2dvOaXNW{Xmfg;($UHn)JYJSy1T<E-gUqR#3?c%Bm(&p$#SsjNoz=Vt)%m
z1E@@e*x#r%!8$6m!J2^)-1-5<0jPF{RF2SgGbk=VH8iw>1b2<WX#f-!Ag_Zux!_C)
z>MemXASj+e5d<#4L2(R<5KuP?)cOZEY(dTaKu{9`ls3VU3ret{$OVN3D5ZgFbWr$!
l5*aA(f+8E7Za@(WN-v;z0=X69?FpcsGo<|r@j7^T0RU-aE9n3L

-- 
GitLab