diff --git a/main.py b/main.py
index a7dbe88212b25872508274bd616a5d8cf9ce34a6..8e1ebb305ca76ec337b7a5b516af98fb940b5259 100644
--- a/main.py
+++ b/main.py
@@ -16,8 +16,13 @@ if __name__ == '__main__':
     learning_rate = 20
     batch_size = 16
     split_train_test_valid = (0.8, 0.1, 0.1)
-    model_path = os.path.join(data_path, 'model.pt')
+    model_path = None # os.path.join(data_path, 'model.pt')
     sequence_length = 32
+    dim_model = 128
+    num_head = 8
+    num_layers = 2
+    num_hid = 200
+    dropout = 0.2
     epochs = 100
     nb_log_epoch = 5
     nb_words = 200
@@ -42,7 +47,7 @@ if __name__ == '__main__':
         with open(model_path, 'rb') as f:
             model = torch.load(f, map_location=device)
     else:
-        model = mg.MusicGenerator(vocab_size=ntokens, dim_model=512, num_head=8).to(device)
+        model = mg.MusicGenerator(vocab_size=ntokens, dim_model=dim_model, num_head=num_head, nhid=num_hid, nlayers=num_layers, dropout=dropout).to(device)
     criterion = nn.NLLLoss()
 
 ###############################################################################
diff --git a/musicgenerator.py b/musicgenerator.py
index 9d8cde017af7c394e7368e8e919ee71e5e20ecfc..cb340f5381662a26df2015fcd747391aaac7af4b 100644
--- a/musicgenerator.py
+++ b/musicgenerator.py
@@ -28,8 +28,8 @@ class PositionalEncoding(nn.Module):
 
 
 class MusicGenerator(nn.Transformer):
-    def __init__(self, vocab_size: int, dim_model: int, num_head: int, dropout=0.5):
-        super(MusicGenerator, self).__init__(d_model=dim_model, nhead=num_head)
+    def __init__(self, vocab_size: int, dim_model: int, num_head: int, nhid: int, nlayers: int, dropout: float):
+        super(MusicGenerator, self).__init__(d_model=dim_model, nhead=num_head, dim_feedforward=nhid, num_encoder_layers=nlayers)
         self.model_type = 'Transformer'
         self.src_mask = None
         self.pos_encoder = PositionalEncoding(dim_model, dropout)
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f6a8e0917e5693e6ccb74518b175e62bf8efb1a
--- /dev/null
+++ b/test.py
@@ -0,0 +1,18 @@
+import argparse
+import random
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--lr', type=float, default=20, help='initial learning rate')
+    parser.add_argument('--epochs', type=int, default=100, help='upper epoch limit')
+    parser.add_argument('--batch_size', type=int, default=16, help='batch size')
+    parser.add_argument('--sequence_length', type=int, default=32, help='sequence length')
+    parser.add_argument('--dimension_model', type=int, default=128, help='size of word embeddings')
+
+    parser.add_argument('--nhead', type=int, default=4, help='the number of heads in the encoder/decoder of the transformer model')
+    parser.add_argument('--dropout', type=float, default=0.2, help='dropout applied to layers (0 = no dropout)')
+    parser.add_argument('--nhid', type=int, default=200, help='number of hidden units per layer')
+    parser.add_argument('--nlayers', type=int, default=2, help='number of layers')
+
+
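
As committed, test.py builds the argument parser but never calls parse_args(), and `import random` is unused. Below is a minimal sketch of how the new flags could feed the updated MusicGenerator constructor; `args`, `ntokens`, and `device` are illustrative stand-ins for the corresponding variables in main.py, and the vocabulary size is a placeholder, not a value from this diff.

    # sketch.py - hypothetical glue between test.py's flags and the new
    # MusicGenerator signature; not part of the diff above.
    import argparse

    import torch

    import musicgenerator as mg

    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('--dimension_model', type=int, default=128, help='size of word embeddings')
        parser.add_argument('--nhead', type=int, default=4, help='number of attention heads')
        parser.add_argument('--nhid', type=int, default=200, help='number of hidden units per layer')
        parser.add_argument('--nlayers', type=int, default=2, help='number of layers')
        parser.add_argument('--dropout', type=float, default=0.2, help='dropout applied to layers (0 = no dropout)')
        args = parser.parse_args()  # test.py as committed stops before this call

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        ntokens = 128  # placeholder; main.py derives this from the corpus vocabulary

        # Mirrors the call added to main.py in this diff, but fed from the CLI.
        model = mg.MusicGenerator(vocab_size=ntokens,
                                  dim_model=args.dimension_model,
                                  num_head=args.nhead,
                                  nhid=args.nhid,
                                  nlayers=args.nlayers,
                                  dropout=args.dropout).to(device)

Two details worth knowing when wiring this up: test.py's default nhead=4 differs from the num_head=8 hard-coded in main.py (both divide dim_model=128 evenly, which nn.Transformer's attention requires), and the super().__init__ call in musicgenerator.py sets only num_encoder_layers, so nlayers controls encoder depth while num_decoder_layers stays at nn.Transformer's default of 6.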