Patch summary: data.py stops appending the '<eos>' token to every line;
main.py calls mg.train() with keyword arguments and, after training, samples
a 100-token melody from the model.

Review notes:
* Leading whitespace and blank context lines were reconstructed from a
  whitespace-collapsed copy of this patch. If a hunk is rejected, retry with
  `git apply --ignore-whitespace` (or `patch -l`) and confirm the indentation.
* The post-image blob ids on the "index" lines predate the review edits.
* main.py review edits: sequence_length is passed by name instead of
  repeating the literal 4, log_interval is clamped to >= 1, the sampling loop
  no longer shadows the builtin input(), and print() gets its separator
  through end= so the output wraps every 20 words.

diff --git a/data.py b/data.py
index 18707d80dd8c159e39d75f10d2101164cf00dcd0..cc80c57a093474d65e3b666dd8a3142f811ace88 100644
--- a/data.py
+++ b/data.py
@@ -33,7 +33,7 @@ class Corpus(object):
         # Add words to the dictionary
         with open(path, 'r', encoding="utf8") as f:
             for line in f:
-                words = line.split() + ['<eos>']
+                words = line.split()  # no '<eos>' marker appended per line
                 for word in words:
                     self.dictionary.add_word(word)
 
@@ -41,7 +41,7 @@ class Corpus(object):
         with open(path, 'r', encoding="utf8") as f:
             idss = []
             for line in f:
-                words = line.split() + ['<eos>']
+                words = line.split()  # no '<eos>' marker appended per line
                 ids = []
                 for word in words:
                     ids.append(self.dictionary.word2idx[word])
diff --git a/main.py b/main.py
index 5098339586e85bf40b8cc0e74f6d711be6b836fc..02ee152709f3eec47c992b86aff13eb6bc3952e7 100644
--- a/main.py
+++ b/main.py
@@ -7,9 +7,9 @@ import musicgenerator as mg
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-
 if __name__ == '__main__':
     learning_rate = 0.01
+    print("device :", device)
 
     ###############################################################################
     # Prépare les données
@@ -28,8 +28,40 @@ if __name__ == '__main__':
     ###############################################################################
     # Entraîne le modèle
     ###############################################################################
-    mg.train(model, criterion, ntokens, train_data, val_data, 4, learning_rate, 5, 500)
-
-
-
+    sequence_length = 4
+    nb_log_epoch = 10
+    # Clamp to 1 so that a small train_data never yields a log interval of 0.
+    log_interval = max(1, (len(train_data) // sequence_length) // nb_log_epoch)
+    mg.train(
+        model=model,
+        criterion=criterion,
+        ntokens=ntokens,
+        train_data=train_data,
+        val_data=val_data,
+        sequence_length=sequence_length,
+        lr=learning_rate,
+        epochs=5,
+        log_interval=log_interval,
+    )
+    ###############################################################################
+    # Generate a melody
+    ###############################################################################
+    nb_words = 100
+    temperature = 1.0
+    model.eval()
+    # Start from one random token id; each sampled id is appended along dim 0.
+    sequence = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
+    with torch.no_grad():
+        for i in range(nb_words):
+            output = model(sequence, False)
+            word_weights = output[-1].squeeze().div(temperature).exp().cpu()
+            # .item() gives a plain int, usable both as an index and a tensor value.
+            word_idx = torch.multinomial(word_weights, 1)[0].item()
+            word_tensor = torch.tensor([[word_idx]], dtype=torch.long, device=device)
+            sequence = torch.cat([sequence, word_tensor], 0)
+
+            word = corpus.dictionary.idx2word[word_idx]
+            # print() appends `end` itself: pass the separator there so the output
+            # wraps every 20 words instead of putting each word on its own line.
+            print(word, end=('\n' if i % 20 == 19 else ' '))
 