Skip to content
Snippets Groups Projects
Commit d54f68da authored by lucien.noel's avatar lucien.noel
Browse files

merge

parents 02d4e0c4 4409bf57
Branches
No related tags found
Loading
......@@ -25,7 +25,7 @@ if __name__ == '__main__':
dropout = 0.2
epochs = 100
nb_log_epoch = 5
nb_words = 200
nb_words = 100
temperature = 1.0
print("device :", device)
......@@ -35,6 +35,7 @@ if __name__ == '__main__':
###############################################################################
corpus = data.Corpus(data_path, model_path is not None, split_train_test_valid)
train_data = data.batchify(corpus.train, batch_size, device)
test_data = data.batchify(corpus.test, batch_size, device)
val_data = data.batchify(corpus.valid, batch_size, device)
###############################################################################
......@@ -58,6 +59,7 @@ if __name__ == '__main__':
criterion=criterion,
ntokens=ntokens,
train_data=train_data,
test_data=test_data,
val_data=val_data,
sequence_length=sequence_length,
lr=learning_rate,
......
......@@ -94,6 +94,7 @@ def train(model: MusicGenerator,
criterion: nn.NLLLoss,
ntokens: int,
train_data: torch.Tensor,
test_data: torch.Tensor,
val_data: torch.Tensor,
sequence_length: int,
lr: float,
......@@ -123,6 +124,12 @@ def train(model: MusicGenerator,
else:
# Anneal the learning rate if no improvement has been seen in the validation dataset.
lr /= 4.0
test_loss = evaluate(model, criterion, test_data, ntokens, sequence_length)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
test_loss, math.exp(test_loss)))
print('=' * 89)
except KeyboardInterrupt:
print('-' * 89)
......@@ -152,7 +159,6 @@ def __train(model: MusicGenerator,
loss = criterion(output, targets)
loss.backward()
# `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
clip = 0.25
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
for p in model.parameters():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment