From 7c7097b52489e607da730bf51b8d1a03e21bc492 Mon Sep 17 00:00:00 2001 From: "lucien.noel" <noellucien2001@gmail.com> Date: Wed, 26 Jun 2024 17:53:23 +0200 Subject: [PATCH] =?UTF-8?q?refactoring=20Transformer=20:=20g=C3=A9n=C3=A9r?= =?UTF-8?q?er=20un=20text=20d'exemple?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data.py | 4 ++-- main.py | 36 +++++++++++++++++++++++++++++++----- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/data.py b/data.py index 18707d8..cc80c57 100644 --- a/data.py +++ b/data.py @@ -33,7 +33,7 @@ class Corpus(object): # Add words to the dictionary with open(path, 'r', encoding="utf8") as f: for line in f: - words = line.split() + ['<eos>'] + words = line.split() # + ['<eos>'] for word in words: self.dictionary.add_word(word) @@ -41,7 +41,7 @@ class Corpus(object): with open(path, 'r', encoding="utf8") as f: idss = [] for line in f: - words = line.split() + ['<eos>'] + words = line.split() # + ['<eos>'] ids = [] for word in words: ids.append(self.dictionary.word2idx[word]) diff --git a/main.py b/main.py index 5098339..02ee152 100644 --- a/main.py +++ b/main.py @@ -7,9 +7,9 @@ import musicgenerator as mg device = "cuda" if torch.cuda.is_available() else "cpu" - if __name__ == '__main__': learning_rate = 0.01 + print("device :",device) ############################################################################### # Prépare les données @@ -28,8 +28,34 @@ if __name__ == '__main__': ############################################################################### # Entraîne le modèle ############################################################################### - mg.train(model, criterion, ntokens, train_data, val_data, 4, learning_rate, 5, 500) - - - + sequence_length = 4 + nb_log_epoch = 10 + mg.train( + model=model, + criterion=criterion, + ntokens=ntokens, + train_data=train_data, + val_data=val_data, + sequence_length=4, + lr=learning_rate, + epochs=5, + log_interval=(len(train_data)//sequence_length)//nb_log_epoch + ) + ############################################################################### + # Génère une mélodie + ############################################################################### + nb_words = 100 + temperature = 1.0 + model.eval() + input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device) + with torch.no_grad(): + for i in range(nb_words): + output = model(input, False) + word_weights = output[-1].squeeze().div(temperature).exp().cpu() + word_idx = torch.multinomial(word_weights, 1)[0] + word_tensor = torch.Tensor([[word_idx]]).long().to(device) + input = torch.cat([input, word_tensor], 0) + + word = corpus.dictionary.idx2word[word_idx] + print(word + ('\n' if i % 20 == 19 else ' ')) -- GitLab