diff --git a/.gitignore b/.gitignore
index 1725bb66d4a9616fd86844b581cf480f53a5a3e3..6b6302cd73a79f15e0b69ae059c8dfa426990059 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 .idea/
 .venv/
-__pycache__/
\ No newline at end of file
+__pycache__/
+*_old.py
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e04937df54cb6df2397d195e8210666a9f5853b8
--- /dev/null
+++ b/README.md
@@ -0,0 +1 @@
+Inspiré du projet [suivant](https://github.com/pytorch/examples/tree/main/word_language_model "Exemple pytorch d'utilisation d'un Transformer")
\ No newline at end of file
diff --git a/data.py b/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..18707d80dd8c159e39d75f10d2101164cf00dcd0
--- /dev/null
+++ b/data.py
@@ -0,0 +1,61 @@
+import os
+from io import open
+import torch
+
+
+class Dictionary(object):
+    def __init__(self) -> None:
+        self.word2idx = {}
+        self.idx2word = []
+
+    def add_word(self, word: str) -> int:
+        if word not in self.word2idx:
+            self.idx2word.append(word)
+            self.word2idx[word] = len(self.idx2word) - 1
+        return self.word2idx[word]
+
+    def __len__(self) -> int:
+        return len(self.idx2word)
+
+
+class Corpus(object):
+    def __init__(self, path: str) -> None:
+        self.dictionary = Dictionary()
+        self.train = self.tokenize(os.path.join(path, './training_data/classical_music/test/input.txt'))
+        # self.valid = self.tokenize(os.path.join(path, './training_data/classical_music/test/input.txt'))
+        # self.test = self.tokenize(os.path.join(path, './training_data/classical_music/test/input.txt'))
+        self.valid = self.train
+        self.test = self.train
+
+    def tokenize(self, path: str) -> torch.Tensor:
+        """Tokenizes a text file."""
+        assert os.path.exists(path), print(f"{path} doesn't exists")
+        # Add words to the dictionary
+        with open(path, 'r', encoding="utf8") as f:
+            for line in f:
+                words = line.split() + ['<eos>']
+                for word in words:
+                    self.dictionary.add_word(word)
+
+        # Tokenize file content
+        with open(path, 'r', encoding="utf8") as f:
+            idss = []
+            for line in f:
+                words = line.split() + ['<eos>']
+                ids = []
+                for word in words:
+                    ids.append(self.dictionary.word2idx[word])
+                idss.append(torch.tensor(ids).type(torch.int64))
+            ids = torch.cat(idss)
+
+        return ids
+
+
+def batchify(data: torch.Tensor, bsz: int, device: str = 'cpu') -> torch.Tensor:
+    # Work out how cleanly we can divide the dataset into bsz parts.
+    nbatch = data.size(0) // bsz
+    # Trim off any extra elements that wouldn't cleanly fit (remainders).
+    data = data.narrow(0, 0, nbatch * bsz)
+    # Evenly divide the data across the bsz batches.
+    data = data.view(bsz, -1).t().contiguous()
+    return data.to(device)
diff --git a/main.py b/main.py
index 43726edfdd5bb51812ee6725b04fc57b9738909d..5098339586e85bf40b8cc0e74f6d711be6b836fc 100644
--- a/main.py
+++ b/main.py
@@ -1,40 +1,35 @@
-import sys
-import os.path
-import numpy as np
 import torch
+import torch.onnx
 import torch.nn as nn
 
-from datetime import datetime
-
-import musicreader as mr
+import data
 import musicgenerator as mg
 
-# constants
-TRAINING_INPUT_FILENAME = "input.txt"
-TRAINING_INFOS_FILENAME = "training_infos.json"
-MODEL_PARAMETER_FILENAME = "model_parameters.pt"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 
-# default parameters
-trainingFolder = os.path.abspath("./training_data/classical_music/test/")
-training = True
-outputPath = os.path.abspath("./output.mid")
+if __name__ == '__main__':
+    learning_rate = 0.01
+
+    ###############################################################################
+    # Prépare les données
+    ###############################################################################
+    corpus = data.Corpus('./')
+    train_data = data.batchify(corpus.train, 16, device)
+    val_data = data.batchify(corpus.valid, 16, device)
+
+    ###############################################################################
+    # Construit le modèle
+    ###############################################################################
+    ntokens = len(corpus.dictionary)
+    model = mg.MusicGenerator(vocab_size=ntokens, dim_model=8, num_head=4).to(device)
+    criterion = nn.NLLLoss()
+
+    ###############################################################################
+    # Entraîne le modèle
+    ###############################################################################
+    mg.train(model, criterion, ntokens, train_data, val_data, 4, learning_rate, 5, 500)
 
 
-def isTrainingNeeded() -> bool:
-    # TODO
-    return training
 
 
-if __name__ == '__main__':
-    if training or isTrainingNeeded():
-        minBpm, maxBpm = mr.readTrainingFolder(trainingFolder, TRAINING_INPUT_FILENAME, TRAINING_INFOS_FILENAME)
-        mg.train(
-            trainingInputFilePath=os.path.join(trainingFolder, TRAINING_INPUT_FILENAME),
-            modelParametersFilePath=os.path.join(trainingFolder, MODEL_PARAMETER_FILENAME),
-            dim_model=1
-        )
-    else:
-        minBpm, maxBpm = mr.getMinMaxBpm(trainingFolder, TRAINING_INFOS_FILENAME)
-
-    bpm = np.random.randint(minBpm, maxBpm + 1)
-    mg.generate(outputPath, bpm,  os.path.join(trainingFolder, MODEL_PARAMETER_FILENAME))
diff --git a/musicgenerator.py b/musicgenerator.py
index 5e212ba5411f7cf55fc07f54c476f4710df76d1b..d4f24a34a09b3fc2d19cda663c2fa42f000f05d5 100644
--- a/musicgenerator.py
+++ b/musicgenerator.py
@@ -1,84 +1,173 @@
 import math
+import time
+
 import torch
 import torch.nn as nn
-import torch.optim as optim
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
+import torch.nn.functional as F
 
 
-# source : https://medium.com/p/c80afbc9ffb1/
 class PositionalEncoding(nn.Module):
-    def __init__(self, dim_model, dropout_p, max_len):
-        super().__init__()
-
-        # Info
-        self.dropout = nn.Dropout(dropout_p)
-
-        # Encoding - From formula
-        pos_encoding = torch.zeros(max_len, dim_model)
-        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)  # 0, 1, 2, 3, 4, 5
-        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)  # 1000^(2i/dim_model)
-
-        # PE(pos, 2i) = sin(pos/1000^(2i/dim_model))
-        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
-
-        # PE(pos, 2i + 1) = cos(pos/1000^(2i/dim_model))
-        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
-
-        # Saving buffer (same as parameter without gradients needed)
-        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
-        self.register_buffer("pos_encoding", pos_encoding)
-
-    def forward(self, token_embedding: torch.tensor) -> torch.tensor:
-        # Residual connection + pos encoding
-        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])
-
-
-class MusicGenerator(nn.Module):
-    def __init__(self, dim_model: int, vocab_size: int, num_heads: int):
-        super().__init__()
-
+    def __init__(self, d_model, dropout=0.1, max_len=5000):
+        super(PositionalEncoding, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).transpose(0, 1)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        x = x + self.pe[:x.size(0), :]
+        return self.dropout(x)
+
+
+class MusicGenerator(nn.Transformer):
+    def __init__(self, vocab_size: int, dim_model: int, num_head: int, dropout=0.5):
+        super(MusicGenerator, self).__init__(d_model=dim_model, nhead=num_head)
+        self.model_type = 'Transformer'
+        self.src_mask = None
+        self.pos_encoder = PositionalEncoding(dim_model, dropout)
+
+        self.input_emb = nn.Embedding(vocab_size, dim_model)
         self.dim_model = dim_model
+        self.decoder = nn.Linear(dim_model, vocab_size)
+
+        self.init_weights()
+
+    def _generate_square_subsequent_mask(self, sz):
+        return torch.log(torch.tril(torch.ones(sz, sz)))
+
+    def init_weights(self):
+        initrange = 0.1
+        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
+        nn.init.zeros_(self.decoder.bias)
+        nn.init.uniform_(self.decoder.weight, -initrange, initrange)
+
+    def forward(self, src, has_mask=True):
+        if has_mask:
+            device = src.device
+            if self.src_mask is None or self.src_mask.size(0) != len(src):
+                mask = self._generate_square_subsequent_mask(len(src)).to(device)
+                self.src_mask = mask
+        else:
+            self.src_mask = None
+
+        src = self.input_emb(src) * math.sqrt(self.dim_model)
+        src = self.pos_encoder(src)
+        output = self.encoder(src, mask=self.src_mask)
+        output = self.decoder(output)
+        return F.log_softmax(output, dim=-1)
+
+
+def get_batch(source: torch.Tensor, i: int, sequence_length: int) -> (torch.Tensor, torch.Tensor):
+    seq_len = min(sequence_length, len(source) - 1 - i)
+    data = source[i:i + seq_len]
+    target = source[i + 1:i + 1 + seq_len].view(-1)
+    return data, target
+
+
+def evaluate(model: MusicGenerator,
+             criterion: nn.NLLLoss,
+             data_source: torch.Tensor,
+             ntokens: int,
+             sequence_length: int,
+             ) -> float:
+    # Turn on evaluation mode which disables dropout.
+    model.eval()
+    total_loss = 0.
+    with torch.no_grad():
+        for i in range(0, data_source.size(0) - 1, sequence_length):
+            data, targets = get_batch(data_source, i, sequence_length)
+            output = model(data)
+            output = output.view(-1, ntokens)
+            total_loss += len(data) * criterion(output, targets).item()
+    return total_loss / (len(data_source) - 1)
+
+
+def train(model: MusicGenerator,
+          criterion: nn.NLLLoss,
+          ntokens: int,
+          train_data: torch.Tensor,
+          val_data: torch.Tensor,
+          sequence_length: int,
+          lr: float,
+          epochs: int,
+          log_interval: int
+          ) -> None:
+    # Loop over epochs.
+    best_val_loss = None
+
+    # At any point you can hit Ctrl + C to break out of training early.
+    try:
+        for epoch in range(1, epochs + 1):
+            epoch_start_time = time.time()
+            __train(model, criterion, ntokens, train_data, sequence_length, lr, epoch, log_interval)
+            val_loss = evaluate(model, criterion, val_data, ntokens, sequence_length)
+            print('-' * 89)
+            print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
+                  'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
+                                             val_loss, math.exp(val_loss)))
+            print('-' * 89)
+            # Save the model if the validation loss is the best we've seen so far.
+            if not best_val_loss or val_loss < best_val_loss:
+                # with open(args.save, 'wb') as f:
+                #     torch.save(model, f)
+                best_val_loss = val_loss
+            else:
+                # Anneal the learning rate if no improvement has been seen in the validation dataset.
+                lr /= 4.0
+    except KeyboardInterrupt:
+        print('-' * 89)
+        print('Exiting from training early')
+
+
+def __train(model: MusicGenerator,
+            criterion: nn.NLLLoss,
+            ntokens: int,
+            train_data: torch.Tensor,
+            sequence_length: int,
+            lr: float,
+            epoch: int,
+            log_interval: int
+            ) -> None:
+    model.train()
+    total_loss = 0
+    start_time = time.time()
+
+    for batch, i in enumerate(range(0, train_data.size(0) - 1, sequence_length)):
+        data, targets = get_batch(train_data, i, sequence_length)
+        # Starting each batch, we detach the hidden state from how it was previously produced.
+        # If we didn't, the model would try backpropagating all the way to start of the dataset.
+        model.zero_grad()
+        output = model(data)
+        output = output.view(-1, ntokens)
+        loss = criterion(output, targets)
+        loss.backward()
 
-        # LAYERS
-        self.positional_encoder_layer = PositionalEncoding(
-            dim_model=dim_model, dropout_p=0.1, max_len=5000
-        )
-        self.embedding_layer = nn.Embedding(vocab_size, dim_model)
-        self.transformer_layer = nn.Transformer(d_model=dim_model, nhead=num_heads)
-        self.out = nn.Linear(dim_model, vocab_size)
-
-    def forward(self, src, tgt):
-        # Src size must be (batch_size, src sequence length)
-        # Tgt size must be (batch_size, tgt sequence length)
-
-        # Embedding + positional encoding - Out size = (batch_size, sequence length, dim_model)
-        src = self.embedding_layer(src) * math.sqrt(self.dim_model)
-        tgt = self.embedding_layer(tgt) * math.sqrt(self.dim_model)
-        src = self.positional_encoder_layer(src)
-        tgt = self.positional_encoder_layer(tgt)
-
-        # we permute to obtain size (sequence length, batch_size, dim_model),
-        src = src.permute(1, 0, 2)
-        tgt = tgt.permute(1, 0, 2)
-
-        # Transformer blocks - Out size = (sequence length, batch_size, num_tokens)
-        transformer_out = self.transformer_layer(src, tgt)
-        out = self.out(transformer_out)
-
-        return out
+        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
+        clip = 0.25
+        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
+        for p in model.parameters():
+            p.data.add_(p.grad, alpha=-lr)
 
+        total_loss += loss.item()
 
-def get_random_batch(data, block_size=256, batch_size=64):
-    # generate a small batch of data of inputs x and targets y
-    ix = torch.randint(len(data) - block_size, (batch_size,))
-    x = torch.stack([data[i:i + block_size] for i in ix])
-    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
-    x, y = x.to(device), y.to(device)
-    return x, y
+        if batch % log_interval == 0 and batch > 0:
+            cur_loss = total_loss / log_interval
+            elapsed = time.time() - start_time
+            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
+                  'loss {:5.2f} | ppl {:8.2f}'.format(epoch, batch, len(train_data) // sequence_length, lr, elapsed * 1000 / log_interval, cur_loss,
+                                                      math.exp(cur_loss)))
+            total_loss = 0
+            start_time = time.time()
+        # if args.dry_run:
+        #     break
 
 
-def train(trainingInputFilePath: str, modelParametersFilePath: str, dim_model: int):
+def train_old(trainingInputFilePath: str, modelParametersFilePath: str, dim_model: int):
     print(f"training from {trainingInputFilePath}")
 
     # code temp.
@@ -92,7 +181,7 @@ def train(trainingInputFilePath: str, modelParametersFilePath: str, dim_model: i
                 break
             vocab.add(char)
 
-    n_vocab = len(vocab)
+    ntokens = len(vocab)
     stoi = {ch: i for i, ch in enumerate(vocab)}
     itos = {i: ch for i, ch in enumerate(vocab)}
     encode = lambda s: [stoi[c] for c in s]  # encoder: take a string, output a list of integers
@@ -104,43 +193,17 @@ def train(trainingInputFilePath: str, modelParametersFilePath: str, dim_model: i
     train_data = data[:n]
     val_data = data[n:]
 
-    print(f"n_vocab = {n_vocab}, dim_model = {dim_model}")
+    print(f"n_vocab = {ntokens}, dim_model = {dim_model}")
 
     # IMPORTANT : dim_model doit être divisible par num_heads
-    mg = MusicGenerator(dim_model=dim_model, vocab_size=n_vocab, num_heads=1).to(device)
-
-    opt = torch.optim.SGD(mg.parameters(), lr=0.01)
-    loss_fn = nn.CrossEntropyLoss()
-
-    mg.train()
-    total_loss = 0
-
-    for i in range(10):
-        print(i)
-        X, y = get_random_batch(train_data)
-        X, y = torch.tensor(X).to(device), torch.tensor(y).to(device)
-
-        # Now we shift the tgt by one so with the <SOS> we predict the token at pos 1
-        y_input = y[:, :-1]
-        y_expected = y[:, 1:]
-
-        # Standard training except we pass in y_input and tgt_mask
-        pred = mg(X, y_input)
-
-        # Permute pred to have batch size first again
-        pred = pred.permute(1, 2, 0)
-        loss = loss_fn(pred, y_expected)
-
-        opt.zero_grad()
-        loss.backward()
-        opt.step()
-
-        total_loss += loss.detach().item()
-
-    with torch.no_grad():
-        print(mg())
-    # context = torch.zeros((1, 1), dtype=torch.long, device=device)
-    # print(decode(mg.generate(context, max_new_tokens=500)[0].tolist()))
+    # model = MusicGenerator(vocab_size=ntokens, dim_model=dim_model, num_head=1).to(device)
+    #
+    # model.train()
+    total_loss = 0.
+    start_time = time.time()
+    sequence_length = 2
+    # for batch, i in enumerate(range(0, n - 1, sequence_length)):
+    #     data, targets = get_batch(train_data, i)
 
 
 def generate(outputFilePath: str, bpm: int, modelParametersFilePath: str):
diff --git a/musicreader.py b/musicreader.py
deleted file mode 100644
index c05ddd3ce66ec20ee846aaa7bb9f3d70e1f17852..0000000000000000000000000000000000000000
--- a/musicreader.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import json
-import os
-import torch
-
-import music21
-
-
-def readTrainingFolder(folderPath: str, trainingTextFilename: str, trainingInfoJsonFilename: str) -> (int, int):
-    """
-    Read all midi files and save it into a text file that can be read for the training process.
-    Also save some info about the training data in trainingInfoJsonFilename.
-    :param folderPath: path to folder containing all training midi files to read
-    :param trainingInfoJsonFilename:
-    :param trainingTextFilename:
-    :return: min and max bpm of the training data
-    """
-    print(f"read all midi files in {folderPath}")
-    print(f"save content in {os.path.join(folderPath, trainingTextFilename)}")
-    print(f"save infos in {os.path.join(folderPath, trainingInfoJsonFilename)}")
-    return 0, 0
-
-
-def getMinMaxBpm(folderPath: str, trainingInfoJsonFilename: str) -> (int, int):
-    """
-    ASSUMING THAT THE TRAINING DATA HAS ALREADY BEEN READ ONCE.
-    Read the trainingInfoJsonFilename file and return the min/max bpm.
-    :param folderPath: path to folder containing all training midi files
-    :param trainingInfoJsonFilename:
-    :return: min and max bpm of the training data
-    """
-    print(f"read infos in {os.path.join(folderPath, trainingInfoJsonFilename)}")
-    return 0, 0
-
-
-# en fait il y a pas besoin de ça car pytorch optimize aussi les données dans les nn.Embedding
-# mais cette fonction pourrait être utile si on voudrait un embedding plus spécifique
-# def createEmbedding(embeddingFilePath: str, trainingTextFilePath: str):
-#     print(f"read {trainingTextFilePath}")
-#     print(f"create embedding {embeddingFilePath}")
-#
-#     # example code
-#     vocab = set()
-#     with open(trainingTextFilePath, 'r') as file:
-#         while True:
-#             char = file.read(1)
-#             if not char:
-#                 break
-#             vocab.add(char)
-#     embedding = {c: [i] for i, c in enumerate(sorted(list(vocab)))}
-#     print(embedding)
-#     torch.save(embedding, embeddingFilePath)