# insert_script.py

import os
import ijson
import argparse
import time
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from py2neo import Graph
from tqdm import tqdm
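
# Neo4j connection settings are read from the environment; NEO4J_URI,
# NEO4J_USER and NEO4J_PASSWORD must be exported before running the script.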
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USER = os.getenv("NEO4J_USER")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
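
# Tuning knobs: how many articles go into one Cypher call and how many
# insert threads run at once.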
BATCH_SIZE = 500
MAX_WORKERS = 2
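

# One batched write per call: UNWIND iterates over the article dicts, MERGE
# upserts each ARTICLE node, its AUTHOR nodes (linked via AUTHORED) and a
# CITES relationship for every referenced article id.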
def batch_insert_articles(graph, articles):
    query = """
    UNWIND $articles AS article
    MERGE (a:ARTICLE { _id: article._id })
    SET a.title = article.title
    FOREACH (author IN article.authors |
        FOREACH (_ IN CASE WHEN author.name IS NOT NULL THEN [1] ELSE [] END |
            MERGE (p:AUTHOR { _id: coalesce(author._id, author.name) })
            SET p.name = author.name
            MERGE (p)-[:AUTHORED]->(a)
        )
    )
    FOREACH (ref_id IN article.references |
        MERGE (r:ARTICLE { _id: ref_id })
        MERGE (a)-[:CITES]->(r)
    )
    """
    graph.run(query, articles=articles)


def main(json_file, limit):
    start_time = time.time()
    print(f"⏱️ Start: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    graph = Graph(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
    print(f"📄 Optimized read of {json_file} (limit: {limit})")

    with open(json_file, 'r', encoding='utf-8') as f:
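        # ijson streams the top-level JSON array one article at a time, so the
        # whole file never has to be loaded into memory.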
        article_iter = ijson.items(f, 'item')

        total = 0
        futures = []
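
        # flush_batch hands the current batch to a worker thread; list(batch)
        # takes a snapshot because the shared list is cleared right away, and
        # executor is resolved at call time inside the pool's with-block below.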
        def flush_batch(batch):
            if batch:
                futures.append(executor.submit(batch_insert_articles, graph, list(batch)))
                batch.clear()
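
        # A small pool (MAX_WORKERS = 2) bounds the number of concurrent write
        # transactions sent to Neo4j.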
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            batch = []
            for article in tqdm(article_iter):
                if limit and total >= limit:
                    break
                batch.append(article)
                total += 1
                if len(batch) >= BATCH_SIZE:
                    flush_batch(batch)
                    time.sleep(0.1)  # brief pause between batch submissions

            # send any remaining articles
            flush_batch(batch)

            # wait for every insert thread to finish
            for future in tqdm(as_completed(futures), total=len(futures), desc="💾 Finalizing insertions"):
                future.result()

    end_time = time.time()
    elapsed_ms = int((end_time - start_time) * 1000)

    print("✅ Import complete")
    print(f"⏱️ End: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"🕓 Total duration: {elapsed_ms:,} ms")
    # max() guards against a division by zero on a near-instant run
    print(f"⚡ Average speed: {int(total * 1000 / max(elapsed_ms, 1))} articles/s")


# --- CLI ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", type=str, default="clean_dblpv13.json", help="Path to the cleaned JSON file")
    parser.add_argument("--limit", type=int, default=1000000, help="Maximum number of articles to load")
    args = parser.parse_args()
    main(args.file, args.limit)
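
# Example run (assuming the Neo4j environment variables above are set; the
# --limit value is illustrative):
#   python insert_script.py --file clean_dblpv13.json --limit 100000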