Skip to content
Snippets Groups Projects
Commit 4409bf57 authored by lucien.noel's avatar lucien.noel
Browse files

ajout d'un étape de test lors de l'entrainement

parent 1ee83c34
Branches
No related tags found
No related merge requests found
......@@ -16,11 +16,11 @@ if __name__ == '__main__':
learning_rate = 20
batch_size = 16
split_train_test_valid = (0.8, 0.1, 0.1)
model_path = os.path.join(data_path, 'model.pt')
model_path = None # os.path.join(data_path, 'model.pt')
sequence_length = 32
epochs = 100
nb_log_epoch = 5
nb_words = 200
nb_words = 100
temperature = 1.0
print("device :", device)
......@@ -30,6 +30,7 @@ if __name__ == '__main__':
###############################################################################
corpus = data.Corpus(data_path, model_path is not None, split_train_test_valid)
train_data = data.batchify(corpus.train, batch_size, device)
test_data = data.batchify(corpus.test, batch_size, device)
val_data = data.batchify(corpus.valid, batch_size, device)
###############################################################################
......@@ -53,6 +54,7 @@ if __name__ == '__main__':
criterion=criterion,
ntokens=ntokens,
train_data=train_data,
test_data=test_data,
val_data=val_data,
sequence_length=sequence_length,
lr=learning_rate,
......
......@@ -94,6 +94,7 @@ def train(model: MusicGenerator,
criterion: nn.NLLLoss,
ntokens: int,
train_data: torch.Tensor,
test_data: torch.Tensor,
val_data: torch.Tensor,
sequence_length: int,
lr: float,
......@@ -123,6 +124,12 @@ def train(model: MusicGenerator,
else:
# Anneal the learning rate if no improvement has been seen in the validation dataset.
lr /= 4.0
test_loss = evaluate(model, criterion, test_data, ntokens, sequence_length)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
test_loss, math.exp(test_loss)))
print('=' * 89)
except KeyboardInterrupt:
print('-' * 89)
......@@ -152,7 +159,6 @@ def __train(model: MusicGenerator,
loss = criterion(output, targets)
loss.backward()
# `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
clip = 0.25
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
for p in model.parameters():
......
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment