insert_script.py

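    # Stream a cleaned DBLP JSON dump with ijson and load it into Neo4j:
    # articles, authors, and citation links are merged in batched UNWIND
    # queries submitted through a small thread pool. Connection settings are
    # read from the NEO4J_URI / NEO4J_USER / NEO4J_PASSWORD environment variables.
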
    import os
    import ijson
    import argparse
    import time
    from datetime import datetime
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from py2neo import Graph
    from tqdm import tqdm
    
    NEO4J_URI = os.getenv("NEO4J_URI")
    NEO4J_USER = os.getenv("NEO4J_USER")
    NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
    BATCH_SIZE = 500
    MAX_WORKERS = 2
    
    def batch_insert_articles(graph, articles):
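        # One UNWIND statement per batch: merge the article node, attach its
        # authors (the FOREACH/CASE idiom skips authors without a name), and
        # create placeholder ARTICLE nodes for every cited reference.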
        query = """
        UNWIND $articles AS article
        MERGE (a:ARTICLE { _id: article._id })
        SET a.title = article.title
    
        FOREACH (author IN article.authors |
            FOREACH (_ IN CASE WHEN author.name IS NOT NULL THEN [1] ELSE [] END |
                MERGE (p:AUTHOR { _id: coalesce(author._id, author.name) })
                SET p.name = author.name
                MERGE (p)-[:AUTHORED]->(a)
            )
        )
    
        FOREACH (ref_id IN article.references |
            MERGE (r:ARTICLE { _id: ref_id })
            MERGE (a)-[:CITES]->(r)
        )
        """
        graph.run(query, articles=articles)
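
    # Optional sketch, not called anywhere in this script: the MERGE-heavy query
    # above benefits greatly from uniqueness constraints on the merge keys.
    # The syntax below assumes Neo4j 4.4+; adapt it to your server version.
    def ensure_constraints(graph):
        graph.run("CREATE CONSTRAINT article_id IF NOT EXISTS FOR (a:ARTICLE) REQUIRE a._id IS UNIQUE")
        graph.run("CREATE CONSTRAINT author_id IF NOT EXISTS FOR (p:AUTHOR) REQUIRE p._id IS UNIQUE")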
    
    def main(json_file, limit):
        start_time = time.time()
        print(f"⏱️ Début : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    
        graph = Graph(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
    
        print(f"📄 Lecture optimisée de {json_file} (limite: {limit})")
    
        with open(json_file, 'r', encoding='utf-8') as f:
            article_iter = ijson.items(f, 'item')
            total = 0
            futures = []
    
            def flush_batch(batch):
                # Submit the current batch to the thread pool and clear it;
                # 'executor' is resolved at call time, inside the
                # ThreadPoolExecutor block below.
                if batch:
                    futures.append(executor.submit(batch_insert_articles, graph, list(batch)))
                    batch.clear()
    
            with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
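                # buffer articles until BATCH_SIZE is reached, then hand the
                # full batch to a worker thread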
                batch = []
                for article in tqdm(article_iter):
                    if limit and total >= limit:
                        break
                    batch.append(article)
                    total += 1
    
                    if len(batch) >= BATCH_SIZE:
                        flush_batch(batch)
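                        # short pause between batch submissions (simple throttling)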
                        time.sleep(0.1)
    
                # flush the remaining articles
                flush_batch(batch)
    
                # wait for every submitted batch to finish
                for future in tqdm(as_completed(futures), total=len(futures), desc="💾 Finalizing inserts"):
                    future.result()
    
        end_time = time.time()
        elapsed_ms = int((end_time - start_time) * 1000)
    
        print(f"✅ Import terminé")
        print(f"⏱️ Fin   : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"🕓 Durée totale : {elapsed_ms:,} ms")
        print(f"⚡ Vitesse moyenne : {int(total / (elapsed_ms / 1000))} it/s")
    
    # --- CLI ---
    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--file", type=str, default="clean_dblpv13.json", help="Chemin vers le fichier JSON nettoyé")
        parser.add_argument("--limit", type=int, default=1000000, help="Nombre maximum d'articles à charger")
        args = parser.parse_args()
    
        main(args.file, args.limit)
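
    # Example invocation (assumes a local Bolt endpoint; adjust to your setup):
    #   NEO4J_URI=bolt://localhost:7687 NEO4J_USER=neo4j NEO4J_PASSWORD=secret \
    #   python insert_script.py --file clean_dblpv13.json --limit 100000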