Skip to content
Snippets Groups Projects
Commit 9e0de491 authored by lucien.noel's avatar lucien.noel
Browse files

adapte la signature de temps selon la signature de temps la plus vue dans les...

adapte la signature de temps selon la signature de temps la plus vue dans les données d'entraînement
parent 42fcc289
No related branches found
No related tags found
No related merge requests found
...@@ -16,7 +16,7 @@ if __name__ == '__main__': ...@@ -16,7 +16,7 @@ if __name__ == '__main__':
parser.add_argument('--batch_size', type=int, default=16, help='batch size') parser.add_argument('--batch_size', type=int, default=16, help='batch size')
parser.add_argument('--sequence_length', type=int, default=32, help='sequence length') parser.add_argument('--sequence_length', type=int, default=32, help='sequence length')
parser.add_argument('--dimension_model', type=int, default=128, help='size of word embeddings') parser.add_argument('--dimension_model', type=int, default=128, help='size of word embeddings')
parser.add_argument('--nhead', type=int, default=4, help='the number of heads in the encoder/decoder of the transformer model') parser.add_argument('--nhead', type=int, default=4, help='the number of heads in the encoder/decoder of the transformer model')
parser.add_argument('--dropout', type=float, default=0.2, help='dropout applied to layers (0 = no dropout)') parser.add_argument('--dropout', type=float, default=0.2, help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--nhid', type=int, default=200, help='number of hidden units per layer') parser.add_argument('--nhid', type=int, default=200, help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=2, help='number of layers') parser.add_argument('--nlayers', type=int, default=2, help='number of layers')
...@@ -60,7 +60,8 @@ if __name__ == '__main__': ...@@ -60,7 +60,8 @@ if __name__ == '__main__':
with open(model_path, 'rb') as f: with open(model_path, 'rb') as f:
model = torch.load(f, map_location=device) model = torch.load(f, map_location=device)
else: else:
model = mg.MusicGenerator(vocab_size=ntokens, dim_model=dim_model, num_head=num_head, nhid=num_hid, nlayers=num_layers, dropout=dropout).to(device) model = mg.MusicGenerator(vocab_size=ntokens, dim_model=dim_model, num_head=num_head, nhid=num_hid, nlayers=num_layers, dropout=dropout).to(
device)
criterion = nn.NLLLoss() criterion = nn.NLLLoss()
############################################################################### ###############################################################################
...@@ -83,18 +84,20 @@ if __name__ == '__main__': ...@@ -83,18 +84,20 @@ if __name__ == '__main__':
############################################################################### ###############################################################################
# Génère une mélodie # Génère une mélodie
############################################################################### ###############################################################################
# if corpus.parser.min_bpm > corpus.parser.max_bpm: if corpus.parser.min_bpm > corpus.parser.max_bpm:
# min_temp = corpus.parser.min_bpm min_temp = corpus.parser.min_bpm
# corpus.parser.min_bpm = corpus.parser.max_bpm corpus.parser.min_bpm = corpus.parser.max_bpm
# corpus.parser.max_bpm = min_temp corpus.parser.max_bpm = min_temp
# bpm = np.random.randint(corpus.parser.min_bpm, corpus.parser.max_bpm + 1) bpm = np.random.randint(corpus.parser.min_bpm, corpus.parser.max_bpm + 1)
# mg.generate( ts = corpus.parser.ts
# model=model, mg.generate(
# ntokens=ntokens, model=model,
# corpus=corpus, ntokens=ntokens,
# nb_words=nb_words, corpus=corpus,
# temperature=temperature, nb_words=nb_words,
# device=device, temperature=temperature,
# data_path=data_path, device=device,
# bpm=bpm data_path=data_path,
# ) bpm=bpm,
ts=ts
)
...@@ -186,7 +186,8 @@ def generate(model: MusicGenerator, ...@@ -186,7 +186,8 @@ def generate(model: MusicGenerator,
temperature: float, temperature: float,
corpus: data.Corpus, corpus: data.Corpus,
data_path: str, data_path: str,
bpm: int bpm: int,
ts: str
) -> None: ) -> None:
model.eval() model.eval()
input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device) input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
...@@ -202,4 +203,4 @@ def generate(model: MusicGenerator, ...@@ -202,4 +203,4 @@ def generate(model: MusicGenerator,
word = corpus.dictionary.idx2word[word_idx] word = corpus.dictionary.idx2word[word_idx]
print(word + ' ', end='') print(word + ' ', end='')
text += word + corpus.parser.word_splitter text += word + corpus.parser.word_splitter
corpus.parser.text_to_midi(os.path.join(data_path, 'output.mid'), text, bpm) corpus.parser.text_to_midi(os.path.join(data_path, 'output.mid'), text, bpm, ts)
...@@ -3,13 +3,15 @@ import re ...@@ -3,13 +3,15 @@ import re
from typing import Tuple from typing import Tuple
import numpy as np import numpy as np
from music21 import converter, tempo, note, chord, stream from music21 import converter, tempo, note, chord, stream, meter
from collections import Counter
class MusicParser: class MusicParser:
def __init__(self, midi_folder_path: str, split_train_test_valid: Tuple[float, float, float]): def __init__(self, midi_folder_path: str, split_train_test_valid: Tuple[float, float, float]):
assert sum(split_train_test_valid) == 1.0, "ERROR : Sum of split_train_test_valid must be 1" assert sum(split_train_test_valid) == 1.0, "ERROR : Sum of split_train_test_valid must be 1"
self.word_splitter = '&' self.word_splitter = '&'
self.ts = '4/4'
self.midi_folder = os.path.normpath(midi_folder_path) self.midi_folder = os.path.normpath(midi_folder_path)
self.train_split = split_train_test_valid[0] self.train_split = split_train_test_valid[0]
self.test_split = split_train_test_valid[1] self.test_split = split_train_test_valid[1]
...@@ -19,12 +21,18 @@ class MusicParser: ...@@ -19,12 +21,18 @@ class MusicParser:
def midi_to_text(self) -> None: def midi_to_text(self) -> None:
textes = [] textes = []
tss = []
for file in os.listdir(self.midi_folder): for file in os.listdir(self.midi_folder):
if file.endswith(".mid") and file != 'output.mid': if file.endswith(".mid") and file != 'output.mid':
bpm = 0 bpm = 0
nb_bpm_part = 0 nb_bpm_part = 0
print(file) print(file)
score = converter.parse(os.path.join(self.midi_folder, file)) score = converter.parse(os.path.join(self.midi_folder, file))
# print(score.show('text'))
tsss = []
for part in score.parts:
tsss.append(part.measure(1).timeSignature)
tss.append(Counter(tsss).most_common(1)[0][0])
score = score.chordify() score = score.chordify()
notes = [] notes = []
for el in score.recurse(): for el in score.recurse():
...@@ -52,6 +60,8 @@ class MusicParser: ...@@ -52,6 +60,8 @@ class MusicParser:
self.min_bpm = bpm self.min_bpm = bpm
textes.append(self.word_splitter.join(notes)) textes.append(self.word_splitter.join(notes))
ts = Counter(tss).most_common(1)[0][0]
self.ts = f"{ts.numerator}/{ts.denominator}"
nb_music = len(textes) nb_music = len(textes)
nb_train = int(nb_music * self.train_split) nb_train = int(nb_music * self.train_split)
nb_test = int(nb_music * self.test_split) nb_test = int(nb_music * self.test_split)
...@@ -69,13 +79,16 @@ class MusicParser: ...@@ -69,13 +79,16 @@ class MusicParser:
for music in data: for music in data:
f.write(f"{music}{self.word_splitter}") f.write(f"{music}{self.word_splitter}")
def text_to_midi(self, midi_file_path: str, text: str, bpm: int) -> None: def text_to_midi(self, midi_file_path: str, text: str, bpm: int, ts: str) -> None:
# Create a Stream object to hold the notes and rests # Create a Stream object to hold the notes and rests
midi_stream = stream.Stream() midi_stream = stream.Stream()
mm = tempo.MetronomeMark(number=bpm) mm = tempo.MetronomeMark(number=bpm)
midi_stream.insert(0, mm) midi_stream.insert(0, mm)
ts = meter.TimeSignature(self.ts)
midi_stream.insert(0, ts)
# Split the sequential text into individual tokens # Split the sequential text into individual tokens
tokens = text.split('&') tokens = text.split('&')
...@@ -96,7 +109,7 @@ class MusicParser: ...@@ -96,7 +109,7 @@ class MusicParser:
midi_stream.append(r) midi_stream.append(r)
# Convert the Stream to a MIDI file and write it to disk # Convert the Stream to a MIDI file and write it to disk
# midi_stream.show('text') midi_stream.show('text')
midi_stream.write('midi', fp=midi_file_path) midi_stream.write('midi', fp=midi_file_path)
def __strQuarterLength_to_float(self, quarterLength: str) -> float: def __strQuarterLength_to_float(self, quarterLength: str) -> float:
......
No preview for this file type
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment