diff --git a/main.py b/main.py index a7dbe88212b25872508274bd616a5d8cf9ce34a6..68808b5693d998d019f3904bf89abe3ec04cd451 100644 --- a/main.py +++ b/main.py @@ -16,11 +16,11 @@ if __name__ == '__main__': learning_rate = 20 batch_size = 16 split_train_test_valid = (0.8, 0.1, 0.1) - model_path = os.path.join(data_path, 'model.pt') + model_path = None # os.path.join(data_path, 'model.pt') sequence_length = 32 epochs = 100 nb_log_epoch = 5 - nb_words = 200 + nb_words = 100 temperature = 1.0 print("device :", device) @@ -30,6 +30,7 @@ if __name__ == '__main__': ############################################################################### corpus = data.Corpus(data_path, model_path is not None, split_train_test_valid) train_data = data.batchify(corpus.train, batch_size, device) + test_data = data.batchify(corpus.test, batch_size, device) val_data = data.batchify(corpus.valid, batch_size, device) ############################################################################### @@ -53,6 +54,7 @@ if __name__ == '__main__': criterion=criterion, ntokens=ntokens, train_data=train_data, + test_data=test_data, val_data=val_data, sequence_length=sequence_length, lr=learning_rate, diff --git a/musicgenerator.py b/musicgenerator.py index 9d8cde017af7c394e7368e8e919ee71e5e20ecfc..60861292045e8def82674deb83055e6c7d9eb4ce 100644 --- a/musicgenerator.py +++ b/musicgenerator.py @@ -94,6 +94,7 @@ def train(model: MusicGenerator, criterion: nn.NLLLoss, ntokens: int, train_data: torch.Tensor, + test_data: torch.Tensor, val_data: torch.Tensor, sequence_length: int, lr: float, @@ -123,6 +124,12 @@ def train(model: MusicGenerator, else: # Anneal the learning rate if no improvement has been seen in the validation dataset. lr /= 4.0 + test_loss = evaluate(model, criterion, test_data, ntokens, sequence_length) + print('=' * 89) + print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( + test_loss, math.exp(test_loss))) + print('=' * 89) + except KeyboardInterrupt: print('-' * 89) @@ -152,7 +159,6 @@ def __train(model: MusicGenerator, loss = criterion(output, targets) loss.backward() - # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. clip = 0.25 torch.nn.utils.clip_grad_norm_(model.parameters(), clip) for p in model.parameters(): diff --git a/training_data/classical_music/output.mid b/training_data/classical_music/output.mid index 3f6efc376cf5c74c47cba84b48c919481dbd2cb3..f426b0b79e84656bd9c37d625b0e1a54512bd457 100644 Binary files a/training_data/classical_music/output.mid and b/training_data/classical_music/output.mid differ