Commit 58a49768 authored by lucien.noel

Refactoring Transformer (Work in progress)

parent c9666307
import sys
import os.path
import numpy as np
import torch
import torch.nn as nn
from datetime import datetime
import musicreader as mr
@@ -9,10 +12,8 @@ import musicgenerator as mg
# constants
TRAINING_INPUT_FILENAME = "input.txt"
TRAINING_INFOS_FILENAME = "training_infos.json"
EMBEDDING_FILENAME = "embeddings.npy"
MODEL_PARAMETER_FILENAME = "model_parameters.pt"
# default parameters
trainingFolder = os.path.abspath("./training_data/classical_music/test/")
training = True
@@ -25,11 +26,13 @@ def isTrainingNeeded() -> bool:
if __name__ == '__main__':
    if training or isTrainingNeeded():
        minBpm, maxBpm = mr.readTrainingFolder(trainingFolder, TRAINING_INPUT_FILENAME, TRAINING_INFOS_FILENAME)
        mr.createEmbedding(os.path.join(trainingFolder, EMBEDDING_FILENAME), os.path.join(trainingFolder, TRAINING_INPUT_FILENAME))
-       mg.train(os.path.join(trainingFolder, TRAINING_INPUT_FILENAME), os.path.join(trainingFolder, EMBEDDING_FILENAME), os.path.join(trainingFolder, MODEL_PARAMETER_FILENAME))
+       mg.train(
+           trainingInputFilePath=os.path.join(trainingFolder, TRAINING_INPUT_FILENAME),
+           modelParametersFilePath=os.path.join(trainingFolder, MODEL_PARAMETER_FILENAME),
+           dim_model=1
+       )
    else:
        minBpm, maxBpm = mr.getMinMaxBpm(trainingFolder, TRAINING_INFOS_FILENAME)
......
-def train(trainingInputFilePath: str, embeddingFilePath: str, modelParametersFilePath: str):
import math
import torch
import torch.nn as nn
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"
# source : https://medium.com/p/c80afbc9ffb1/
class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)

        # Encoding - from the formula in "Attention Is All You Need"
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)  # 0, 1, 2, 3, 4, 5, ...
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)  # 1 / 10000^(2i/dim_model)
        # PE(pos, 2i) = sin(pos / 10000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i/dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)

        # Saved as a buffer (like a parameter, but without gradients); shape (max_len, 1, dim_model)
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        # Residual connection + positional encoding; expects (seq_len, batch_size, dim_model) input
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])
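
A quick shape check (illustrative sketch, not part of the commit): the registered buffer has shape (max_len, 1, dim_model), so forward expects sequence-first input of shape (seq_len, batch_size, dim_model).

import torch

pe = PositionalEncoding(dim_model=64, dropout_p=0.1, max_len=5000)
x = torch.zeros(16, 4, 64)       # (seq_len, batch_size, dim_model)
print(pe.pos_encoding.shape)     # torch.Size([5000, 1, 64])
print(pe(x).shape)               # torch.Size([16, 4, 64])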
class MusicGenerator(nn.Module):
    def __init__(self, dim_model: int, vocab_size: int, num_heads: int):
        super().__init__()
        self.dim_model = dim_model

        # LAYERS
        self.positional_encoder_layer = PositionalEncoding(
            dim_model=dim_model, dropout_p=0.1, max_len=5000
        )
        self.embedding_layer = nn.Embedding(vocab_size, dim_model)
        self.transformer_layer = nn.Transformer(d_model=dim_model, nhead=num_heads)
        self.out = nn.Linear(dim_model, vocab_size)

    def forward(self, src, tgt):
        # Src size must be (batch_size, src sequence length)
        # Tgt size must be (batch_size, tgt sequence length)

        # Embedding - out size = (batch_size, sequence length, dim_model)
        src = self.embedding_layer(src) * math.sqrt(self.dim_model)
        tgt = self.embedding_layer(tgt) * math.sqrt(self.dim_model)

        # Permute to (sequence length, batch_size, dim_model) *before* adding the
        # positional encoding, since the encoding buffer indexes the sequence dimension first
        src = self.positional_encoder_layer(src.permute(1, 0, 2))
        tgt = self.positional_encoder_layer(tgt.permute(1, 0, 2))

        # Transformer blocks - out size = (sequence length, batch_size, vocab_size)
        transformer_out = self.transformer_layer(src, tgt)
        out = self.out(transformer_out)
        return out
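
Note: the training loop below says it passes a tgt_mask, but no mask is ever built, so the decoder can currently attend to future target tokens. One possible way to wire in a causal mask, sketched under the assumption that forward stays sequence-first (forward_with_mask is a hypothetical name, not in the commit):

def forward_with_mask(self, src, tgt):
    src = self.embedding_layer(src) * math.sqrt(self.dim_model)
    tgt = self.embedding_layer(tgt) * math.sqrt(self.dim_model)
    src = self.positional_encoder_layer(src.permute(1, 0, 2))
    tgt = self.positional_encoder_layer(tgt.permute(1, 0, 2))
    # causal mask: position i may only attend to positions <= i
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
    return self.out(self.transformer_layer(src, tgt, tgt_mask=tgt_mask))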
def get_random_batch(data, block_size=256, batch_size=64):
    # generate a small batch of inputs x and targets y (y is x shifted by one token)
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
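
A small sanity check of the one-token shift (illustrative sketch; the sampled windows are random):

data = torch.arange(1000)
x, y = get_random_batch(data, block_size=8, batch_size=2)
print(x[0])  # e.g. tensor([412, 413, 414, 415, 416, 417, 418, 419])
print(y[0])  # the same window shifted by one token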
def train(trainingInputFilePath: str, modelParametersFilePath: str, dim_model: int):
    print(f"training from {trainingInputFilePath}")
-    print(f"read embedding from {embeddingFilePath}")
    print(f"save training in {modelParametersFilePath}")
-    pass

    # temporary code: read the training text and build a character-level vocabulary
    with open(trainingInputFilePath, "r") as f:
        data = f.read()
    vocab = sorted(set(data))  # sorted, so token ids are reproducible across runs
    n_vocab = len(vocab)
    stoi = {ch: i for i, ch in enumerate(vocab)}
    itos = {i: ch for i, ch in enumerate(vocab)}
    encode = lambda s: [stoi[c] for c in s]  # encoder: take a string, output a list of integers
    decode = lambda l: ''.join([itos[i] for i in l])  # decoder: take a list of integers, output a string

    # Train and test splits
    data = torch.tensor(encode(data), dtype=torch.long)
    n = int(0.9 * len(data))  # first 90% will be train, rest val
    train_data = data[:n]
    val_data = data[n:]

    print(f"n_vocab = {n_vocab}, dim_model = {dim_model}")
    # IMPORTANT: dim_model must be divisible by num_heads
    mg = MusicGenerator(dim_model=dim_model, vocab_size=n_vocab, num_heads=1).to(device)
    opt = optim.SGD(mg.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()

    mg.train()
    total_loss = 0
    for i in range(10):
        print(i)
        X, y = get_random_batch(train_data)  # already tensors on the right device
        # Shift the targets by one so the model predicts the next token at every position
        y_input = y[:, :-1]
        y_expected = y[:, 1:]
        # Standard training step (no tgt_mask is passed yet; see the masking sketch above)
        pred = mg(X, y_input)
        # Permute pred to (batch_size, vocab_size, seq_len), as CrossEntropyLoss expects
        pred = pred.permute(1, 2, 0)
        loss = loss_fn(pred, y_expected)
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.detach().item()

    # generation is still TODO:
    # context = torch.zeros((1, 1), dtype=torch.long, device=device)
    # print(decode(mg.generate(context, max_new_tokens=500)[0].tolist()))
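
The commented-out lines assume a mg.generate(context, max_new_tokens=...) method that this commit does not implement. One possible greedy-decoding sketch (hypothetical; feeding context as both src and tgt is a simplification):

def generate(model, context, max_new_tokens):
    # context: (batch_size, seq_len) of token ids
    model.eval()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(context, context)        # (seq_len, batch_size, vocab_size)
            next_token = logits[-1].argmax(dim=-1)  # greedy choice at the last position
            context = torch.cat([context, next_token.unsqueeze(1)], dim=1)
    return context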
def generate(outputFilePath: str, bpm: int, modelParametersFilePath: str):
......
import json
import os
import torch
import music21
@@ -31,7 +32,20 @@ def getMinMaxBpm(folderPath: str, trainingInfoJsonFilename: str) -> (int, int):
    return 0, 0
-def createEmbedding(embeddingFilePath: str, trainingTextFilePath: str):
-    print(f"read {trainingTextFilePath}")
-    print(f"create embedding {embeddingFilePath}")
-    pass
+# actually this is not needed, because pytorch also optimizes the data inside nn.Embedding,
+# but this function could be useful if we wanted a more specific embedding
+# def createEmbedding(embeddingFilePath: str, trainingTextFilePath: str):
+#     print(f"read {trainingTextFilePath}")
+#     print(f"create embedding {embeddingFilePath}")
+#
+#     # example code
+#     vocab = set()
+#     with open(trainingTextFilePath, 'r') as file:
+#         while True:
+#             char = file.read(1)
+#             if not char:
+#                 break
+#             vocab.add(char)
+#     embedding = {c: [i] for i, c in enumerate(sorted(list(vocab)))}
+#     print(embedding)
+#     torch.save(embedding, embeddingFilePath)
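
To back up the note above: the weight matrix inside nn.Embedding is already a learnable parameter, so a separate embedding file is indeed unnecessary for training (quick check, not part of the commit):

import torch.nn as nn

emb = nn.Embedding(num_embeddings=100, embedding_dim=16)
print(emb.weight.requires_grad)  # True -> the optimizer updates it like any other layer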
music21==9.1.0
numpy==1.26.4
torch==2.3.1