diff --git a/main.py b/main.py
index a7dbe88212b25872508274bd616a5d8cf9ce34a6..8e1ebb305ca76ec337b7a5b516af98fb940b5259 100644
--- a/main.py
+++ b/main.py
@@ -16,8 +16,13 @@ if __name__ == '__main__':
     learning_rate = 20
     batch_size = 16
     split_train_test_valid = (0.8, 0.1, 0.1)
-    model_path = os.path.join(data_path, 'model.pt')
+    model_path = None  # set to os.path.join(data_path, 'model.pt') to load a saved model
     sequence_length = 32
+    dim_model = 128
+    num_head = 8
+    num_layers = 2
+    num_hid = 200
+    dropout = 0.2
     epochs = 100
     nb_log_epoch = 5
     nb_words = 200
@@ -42,7 +47,7 @@ if __name__ == '__main__':
         with open(model_path, 'rb') as f:
             model = torch.load(f, map_location=device)
     else:
-        model = mg.MusicGenerator(vocab_size=ntokens, dim_model=512, num_head=8).to(device)
+        model = mg.MusicGenerator(vocab_size=ntokens, dim_model=dim_model, num_head=num_head, nhid=num_hid, nlayers=num_layers, dropout=dropout).to(device)
         criterion = nn.NLLLoss()
 
         ###############################################################################
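
With `model_path = None`, the load branch above is skipped and a fresh model is built from the newly named hyperparameters. A minimal sketch of the resulting checkpoint gate, assuming the surrounding main.py condition tests the path (`ntokens` is a placeholder for the dataset vocabulary size):

```python
# Sketch of the checkpoint gate after this patch (assumed context).
import os
import torch
import musicgenerator as mg

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ntokens = 512      # placeholder: normally the dataset vocabulary size
model_path = None  # set to a saved .pt path to resume instead of training fresh

if model_path is not None and os.path.exists(model_path):
    # Resume from a checkpoint only when a path is configured.
    with open(model_path, 'rb') as f:
        model = torch.load(f, map_location=device)
else:
    # Build a fresh model from the named hyperparameters.
    model = mg.MusicGenerator(vocab_size=ntokens, dim_model=128, num_head=8,
                              nhid=200, nlayers=2, dropout=0.2).to(device)
```
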
diff --git a/musicgenerator.py b/musicgenerator.py
index 9d8cde017af7c394e7368e8e919ee71e5e20ecfc..cb340f5381662a26df2015fcd747391aaac7af4b 100644
--- a/musicgenerator.py
+++ b/musicgenerator.py
@@ -28,8 +28,8 @@ class PositionalEncoding(nn.Module):
 
 
 class MusicGenerator(nn.Transformer):
-    def __init__(self, vocab_size: int, dim_model: int, num_head: int, dropout=0.5):
-        super(MusicGenerator, self).__init__(d_model=dim_model, nhead=num_head)
+    def __init__(self, vocab_size: int, dim_model: int, num_head: int, nhid: int, nlayers: int, dropout: float):
+        super(MusicGenerator, self).__init__(d_model=dim_model, nhead=num_head, dim_feedforward=nhid, num_encoder_layers=nlayers, dropout=dropout)
         self.model_type = 'Transformer'
         self.src_mask = None
         self.pos_encoder = PositionalEncoding(dim_model, dropout)
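
The widened `super().__init__` call forwards the new sizes straight to `torch.nn.Transformer`. A standalone shape check (a sketch, not part of the patch) confirms the resized stack: with `d_model=128` and `nhead=8`, each head attends over 16 dimensions, and tensors keep the `(sequence_length, batch_size, dim_model)` layout throughout:

```python
# Standalone sanity check for the resized transformer (not in the patch).
import torch
import torch.nn as nn

transformer = nn.Transformer(d_model=128, nhead=8, dim_feedforward=200,
                             num_encoder_layers=2, dropout=0.2)
src = torch.rand(32, 16, 128)  # (sequence_length, batch_size, dim_model)
tgt = torch.rand(32, 16, 128)
out = transformer(src, tgt)
print(out.shape)               # torch.Size([32, 16, 128])
```
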
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f6a8e0917e5693e6ccb74518b175e62bf8efb1a
--- /dev/null
+++ b/test.py
@@ -0,0 +1,17 @@
+import argparse
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--lr', type=float, default=20.0, help='initial learning rate')
+    parser.add_argument('--epochs', type=int, default=100, help='upper epoch limit')
+    parser.add_argument('--batch_size', type=int, default=16, help='batch size')
+    parser.add_argument('--sequence_length', type=int, default=32, help='sequence length')
+    parser.add_argument('--dimension_model', type=int, default=128, help='size of word embeddings')
+
+    parser.add_argument('--nhead', type=int, default=4, help='the number of heads in the encoder/decoder of the transformer model')
+    parser.add_argument('--dropout', type=float, default=0.2, help='dropout applied to layers (0 = no dropout)')
+    parser.add_argument('--nhid', type=int, default=200, help='number of hidden units per layer')
+    parser.add_argument('--nlayers', type=int, default=2, help='number of layers')
+
+    args = parser.parse_args()
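
test.py only defines the CLI surface for the same hyperparameters main.py hard-codes; note the `--nhead` default (4) differs from the `num_head = 8` used in main.py. A hypothetical wiring of the parsed arguments into the model (not part of this patch; `ntokens` is again a placeholder):

```python
# Hypothetical wiring (not in the patch): drive MusicGenerator from the CLI.
import argparse
import musicgenerator as mg

parser = argparse.ArgumentParser()
parser.add_argument('--dimension_model', type=int, default=128)
parser.add_argument('--nhead', type=int, default=4)
parser.add_argument('--nhid', type=int, default=200)
parser.add_argument('--nlayers', type=int, default=2)
parser.add_argument('--dropout', type=float, default=0.2)
args = parser.parse_args()

ntokens = 512  # placeholder: normally the dataset vocabulary size
model = mg.MusicGenerator(vocab_size=ntokens,
                          dim_model=args.dimension_model,
                          num_head=args.nhead,
                          nhid=args.nhid,
                          nlayers=args.nlayers,
                          dropout=args.dropout)
```
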