"""Download PDF documents from Google Cloud Storage, split them into text
chunks, and generate embeddings through a Vertex AI prediction endpoint.

Usage:
    python script.py --bucket_name BUCKET --local_path DIR [--index_name NAME]
"""

import argparse
import os

from google.cloud import aiplatform, storage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFDirectoryLoader

# Google Cloud authentication: point the SDK at the service-account key file.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "service-account-file.json"

# Vertex AI defaults; generate_embeddings() accepts overrides for other
# projects/endpoints while keeping the original call signature working.
DEFAULT_PROJECT = "mse-test-project-436514"
DEFAULT_LOCATION = "us-central1"
DEFAULT_ENDPOINT_ID = "2223196018688655360"

# Shared Google Cloud Storage client, reused by all downloads.
storage_client = storage.Client()


def download_documents(bucket_name, local_dir):
    """Download every ``.pdf`` blob from *bucket_name* into *local_dir*.

    Blob names may contain ``/`` separators; the matching local
    sub-directories (and *local_dir* itself) are created as needed.
    """
    bucket = storage_client.bucket(bucket_name)
    for blob in bucket.list_blobs():
        if not blob.name.endswith('.pdf'):
            continue
        local_filename = os.path.join(local_dir, blob.name)
        # Fix: blob names can include sub-directories, and local_dir may not
        # exist yet — create the target directory first, otherwise
        # download_to_filename raises FileNotFoundError.
        os.makedirs(os.path.dirname(local_filename) or ".", exist_ok=True)
        blob.download_to_filename(local_filename)
        print(f'Downloaded {blob.name} to {local_filename}')


def split_text(docs, chunk_size, chunk_overlap):
    """Split loaded documents into overlapping text chunks.

    Returns the list of chunk documents produced by
    RecursiveCharacterTextSplitter.split_documents().
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return text_splitter.split_documents(docs)


def generate_embeddings(texts, project=DEFAULT_PROJECT,
                        location=DEFAULT_LOCATION,
                        endpoint_id=DEFAULT_ENDPOINT_ID):
    """Generate embeddings for *texts* via a Vertex AI endpoint.

    Parameters default to the original hard-coded project/location/endpoint,
    so existing single-argument calls behave exactly as before.
    Returns the endpoint's list of predictions.
    """
    # Fix: skip the remote call entirely for an empty input — an empty
    # instances list is pointless and may be rejected by the endpoint.
    if not texts:
        return []
    aiplatform.init(project=project, location=location)
    endpoint = aiplatform.Endpoint(endpoint_id)
    # The deployed model expects each instance as {"inputs": <text>}.
    instances = [{"inputs": text} for text in texts]
    response = endpoint.predict(instances=instances)
    return response.predictions


def main(bucket_name, index_name, local_path):
    """Run the full pipeline: download PDFs, chunk them, embed the chunks.

    *index_name* is accepted for CLI compatibility but is not used yet —
    no vector-store write is implemented (see the TODO below).
    """
    download_documents(bucket_name, local_path)
    loader = PyPDFDirectoryLoader(local_path)
    docs = loader.load()
    print('Start chunking')
    chunks = split_text(docs, 1000, 100)
    texts = [chunk.page_content for chunk in chunks]
    print('Start vectorizing')
    embeddings = generate_embeddings(texts)
    # TODO: store the embeddings (e.g. in the index named by index_name).
    print('Embeddings generated:', embeddings)
    print('End processing')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Process PDF documents and store their embeddings."
    )
    # Fix: bucket_name and local_path are used unconditionally — make them
    # required so argparse fails fast instead of a TypeError mid-pipeline.
    parser.add_argument("--bucket_name", required=True,
                        help="The GCS bucket name where documents are stored")
    parser.add_argument("--index_name",
                        help="The name of the index for storing embeddings (if applicable)")
    parser.add_argument("--local_path", required=True,
                        help="Local path to store downloaded files")
    args = parser.parse_args()
    main(args.bucket_name, args.index_name, args.local_path)