diff --git a/main.py b/main.py
index 8e1ebb305ca76ec337b7a5b516af98fb940b5259..2939f46a1b7b8c82c9c366e1f0115613c75a8367 100644
--- a/main.py
+++ b/main.py
@@ -25,7 +25,7 @@ if __name__ == '__main__':
     dropout = 0.2
     epochs = 100
     nb_log_epoch = 5
-    nb_words = 200
+    nb_words = 100
     temperature = 1.0
 
     print("device :", device)
@@ -35,6 +35,7 @@ if __name__ == '__main__':
     ###############################################################################
     corpus = data.Corpus(data_path, model_path is not None, split_train_test_valid)
     train_data = data.batchify(corpus.train, batch_size, device)
+    test_data = data.batchify(corpus.test, batch_size, device)
     val_data = data.batchify(corpus.valid, batch_size, device)
 
     ###############################################################################
@@ -58,6 +59,7 @@ if __name__ == '__main__':
           criterion=criterion,
           ntokens=ntokens,
           train_data=train_data,
+          test_data=test_data,
          val_data=val_data,
           sequence_length=sequence_length,
           lr=learning_rate,
diff --git a/musicgenerator.py b/musicgenerator.py
index cb340f5381662a26df2015fcd747391aaac7af4b..c684346c06bd261df161aa10111f919f1134c56b 100644
--- a/musicgenerator.py
+++ b/musicgenerator.py
@@ -94,6 +94,7 @@ def train(model: MusicGenerator,
           criterion: nn.NLLLoss,
           ntokens: int,
           train_data: torch.Tensor,
+          test_data: torch.Tensor,
           val_data: torch.Tensor,
           sequence_length: int,
           lr: float,
@@ -123,6 +124,12 @@ def train(model: MusicGenerator,
             else:
                 # Anneal the learning rate if no improvement has been seen in the validation dataset.
                 lr /= 4.0
+        test_loss = evaluate(model, criterion, test_data, ntokens, sequence_length)
+        print('=' * 89)
+        print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
+            test_loss, math.exp(test_loss)))
+        print('=' * 89)
+
 
     except KeyboardInterrupt:
         print('-' * 89)
@@ -152,7 +159,6 @@ def __train(model: MusicGenerator,
         loss = criterion(output, targets)
         loss.backward()
 
-        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
         clip = 0.25
         torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
         for p in model.parameters():
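Note: the end-of-training report added above calls an `evaluate` helper that already exists in musicgenerator.py but is not shown in this patch. For context, here is a minimal sketch of what such a helper typically looks like in models following the PyTorch word_language_model pattern; `get_batch`, `repackage_hidden`, and `model.init_hidden` are assumptions about this repo's internals, not confirmed by the diff:

    def evaluate(model, criterion, data_source, ntokens, sequence_length):
        model.eval()  # disable dropout so scoring is deterministic
        total_loss = 0.0
        # Assumed helper: batchified data is (seq_len, batch), so size(1) is the batch size.
        hidden = model.init_hidden(data_source.size(1))
        with torch.no_grad():  # no gradients needed during evaluation
            for i in range(0, data_source.size(0) - 1, sequence_length):
                # Assumed helper: slices one (data, targets) chunk of up to sequence_length steps.
                data, targets = get_batch(data_source, i, sequence_length)
                output, hidden = model(data, hidden)
                # Assumed helper: detaches the hidden state between chunks.
                hidden = repackage_hidden(hidden)
                # Weight each chunk's mean loss by its length, as in __train's criterion call.
                total_loss += len(data) * criterion(output, targets).item()
        return total_loss / (len(data_source) - 1)

The reported `test ppl` is simply `math.exp(test_loss)`: with NLLLoss averaged per token, exponentiating the mean negative log-likelihood yields the per-token perplexity.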