Commit 69e2b754 authored by lucien.noel

Start of tests to find the best parameters

parent 1ee83c34
@@ -16,8 +16,13 @@ if __name__ == '__main__':
     learning_rate = 20
     batch_size = 16
     split_train_test_valid = (0.8, 0.1, 0.1)
-    model_path = os.path.join(data_path, 'model.pt')
+    model_path = None  # os.path.join(data_path, 'model.pt')
     sequence_length = 32
+    dim_model = 128
+    num_head = 8
+    num_layers = 2
+    num_hid = 200
+    dropout = 0.2
     epochs = 100
     nb_log_epoch = 5
     nb_words = 200
@@ -42,7 +47,7 @@ if __name__ == '__main__':
         with open(model_path, 'rb') as f:
             model = torch.load(f, map_location=device)
     else:
-        model = mg.MusicGenerator(vocab_size=ntokens, dim_model=512, num_head=8).to(device)
+        model = mg.MusicGenerator(vocab_size=ntokens, dim_model=dim_model, num_head=num_head, nhid=num_hid, nlayers=num_layers, dropout=dropout).to(device)
     criterion = nn.NLLLoss()
     ###############################################################################
......
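A note for context (the criterion line above is unchanged by this commit): nn.NLLLoss expects log-probabilities, so the model's forward pass is presumably expected to end in a log_softmax. A minimal, self-contained illustration of that pairing:

import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.NLLLoss()
logits = torch.randn(4, 10)                  # (batch, vocab) raw scores
log_probs = F.log_softmax(logits, dim=-1)    # NLLLoss expects log-probabilities
targets = torch.randint(0, 10, (4,))
loss = criterion(log_probs, targets)         # matches CrossEntropyLoss applied to the raw logits
print(loss.item())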
@@ -28,8 +28,8 @@ class PositionalEncoding(nn.Module):
 class MusicGenerator(nn.Transformer):
-    def __init__(self, vocab_size: int, dim_model: int, num_head: int, dropout=0.5):
-        super(MusicGenerator, self).__init__(d_model=dim_model, nhead=num_head)
+    def __init__(self, vocab_size: int, dim_model: int, num_head: int, nhid: int, nlayers: int, dropout: float):
+        super(MusicGenerator, self).__init__(d_model=dim_model, nhead=num_head, dim_feedforward=nhid, num_encoder_layers=nlayers)
         self.model_type = 'Transformer'
         self.src_mask = None
         self.pos_encoder = PositionalEncoding(dim_model, dropout)
......
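For illustration only, a hypothetical call site for the new signature (the module's file name is not shown in the diff; the 'mg' alias follows the main script above). Note that dropout is now a required parameter but is still only passed to PositionalEncoding, so the nn.Transformer base class keeps its default dropout of 0.1.

import torch
import music_generator as mg  # assumed module name for the file containing MusicGenerator

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ntokens = 512  # placeholder vocabulary size

model = mg.MusicGenerator(
    vocab_size=ntokens,
    dim_model=128,  # d_model; must be divisible by num_head (128 / 8 = 16)
    num_head=8,
    nhid=200,       # dim_feedforward of each encoder layer
    nlayers=2,      # num_encoder_layers
    dropout=0.2,    # consumed by PositionalEncoding only
).to(device)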
test.py (new file, mode 100644)
+import argparse
+import random
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--lr', type=float, default=20, help='initial learning rate')
+    parser.add_argument('--epochs', type=int, default=100, help='upper epoch limit')
+    parser.add_argument('--batch_size', type=int, default=16, help='batch size')
+    parser.add_argument('--sequence_length', type=int, default=32, help='sequence length')
+    parser.add_argument('--dimension_model', type=int, default=128, help='size of word embeddings')
+    parser.add_argument('--nhead', type=int, default=4, help='the number of heads in the encoder/decoder of the transformer model')
+    parser.add_argument('--dropout', type=float, default=0.2, help='dropout applied to layers (0 = no dropout)')
+    parser.add_argument('--nhid', type=int, default=200, help='number of hidden units per layer')
+    parser.add_argument('--nlayers', type=int, default=2, help='number of layers')
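Given the commit message ("start of tests to find the best parameters") and the random import, test.py presumably grows into a hyperparameter sweep. A minimal sketch of one plausible continuation, with illustrative values and a hypothetical train_and_evaluate helper:

import itertools
import random

# Small grid around the defaults above; the values are illustrative, not from the commit.
grid = {
    'lr': [5.0, 10.0, 20.0],
    'dimension_model': [64, 128, 256],
    'nlayers': [2, 4],
}
combos = list(itertools.product(*grid.values()))
random.shuffle(combos)  # randomized search order, matching the random import
for lr, dimension_model, nlayers in combos[:6]:  # cap the budget at 6 runs
    print(f'run: lr={lr} dimension_model={dimension_model} nlayers={nlayers}')
    # loss = train_and_evaluate(lr, dimension_model, nlayers)  # hypothetical helper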