Skip to content
Snippets Groups Projects
Commit fedeeb6f authored by LilDotTheGod's avatar LilDotTheGod
Browse files

Ajout de l'option pour sauvegarder/charger un model

parent c34cb4ae
No related branches found
No related tags found
No related merge requests found
.idea/ .idea/
.venv/ .venv/
__pycache__/ __pycache__/
*_old.py *_old.py
\ No newline at end of file model.pt
\ No newline at end of file
import os
import torch import torch
import torch.onnx # import torch.onnx
import torch.nn as nn import torch.nn as nn
import data import data
...@@ -8,13 +9,23 @@ import musicgenerator as mg ...@@ -8,13 +9,23 @@ import musicgenerator as mg
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
if __name__ == '__main__': if __name__ == '__main__':
###############################################################################
# Paramètres
###############################################################################
learning_rate = 20 learning_rate = 20
batch_size = 16
model_path = None # os.path.normpath('./model.pt')
sequence_length = 32
epochs = 100
nb_log_epoch = 5
nb_words = 100
temperature = 1.0
print("device :", device) print("device :", device)
############################################################################### ###############################################################################
# Prépare les données # Prépare les données
############################################################################### ###############################################################################
batch_size = 16
corpus = data.Corpus('./') corpus = data.Corpus('./')
train_data = data.batchify(corpus.train, batch_size, device) train_data = data.batchify(corpus.train, batch_size, device)
val_data = data.batchify(corpus.valid, batch_size, device) val_data = data.batchify(corpus.valid, batch_size, device)
...@@ -23,32 +34,33 @@ if __name__ == '__main__': ...@@ -23,32 +34,33 @@ if __name__ == '__main__':
# Construit le modèle # Construit le modèle
############################################################################### ###############################################################################
ntokens = len(corpus.dictionary) ntokens = len(corpus.dictionary)
model = mg.MusicGenerator(vocab_size=ntokens, dim_model=512, num_head=8).to(device)
criterion = nn.NLLLoss()
############################################################################### if model_path:
# Entraîne le modèle print("Load model from", model_path)
############################################################################### with open(model_path, 'rb') as f:
sequence_length = 32 model = torch.load(f, map_location=device)
nb_log_epoch = 5 else:
epochs = 100 model = mg.MusicGenerator(vocab_size=ntokens, dim_model=512, num_head=8).to(device)
mg.train( criterion = nn.NLLLoss()
model=model,
criterion=criterion, ###############################################################################
ntokens=ntokens, # Entraîne le modèle
train_data=train_data, ###############################################################################
val_data=val_data, mg.train(
sequence_length=sequence_length, model=model,
lr=learning_rate, criterion=criterion,
epochs=epochs, ntokens=ntokens,
log_interval=(len(train_data) // sequence_length) // nb_log_epoch train_data=train_data,
) val_data=val_data,
sequence_length=sequence_length,
lr=learning_rate,
epochs=epochs,
log_interval=(len(train_data) // sequence_length) // nb_log_epoch
)
############################################################################### ###############################################################################
# Génère une mélodie # Génère une mélodie
############################################################################### ###############################################################################
nb_words = 100
temperature = 1.0
model.eval() model.eval()
input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device) input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
with torch.no_grad(): with torch.no_grad():
......
Source diff could not be displayed: it is too large. Options to address this: view the blob.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment