Skip to content
Snippets Groups Projects
Commit a5728da8 authored by abir.chebbi
Browse files

adjust the creation of the vectorDB

parent 65c2ecaa
No related branches found
No related tags found
No related merge requests found
import boto3
import os
import argparse
LOCAL_DIR = "pdfs"
BUCKET_NAME = 'cloud-lecture-nabil-2024-25'
# Initiate S3 client
s3_client = boto3.client('s3')
# Create S3 Bucket
def create_bucket(s3_client, bucket_name):
    """Create an S3 bucket and print the service response.

    Args:
        s3_client: S3 client object exposing ``create_bucket(Bucket=...)``.
        bucket_name: Name of the bucket to create.
    """
    print("Creating Bucket")
    # Create the bucket exactly once, using the caller-supplied name rather
    # than the module-level BUCKET_NAME constant.
    # NOTE(review): no CreateBucketConfiguration is passed, so this presumably
    # relies on the client's region being us-east-1 -- confirm for other regions.
    response = s3_client.create_bucket(Bucket=bucket_name)
    print(response)
    print()
# Function to write files to S3
def write_files(directory, bucket):
def write_files(s3_client, directory, bucket):
for filename in os.listdir(directory):
if filename.endswith(".pdf"): # Check if the file is a PDF
file_path = os.path.join(directory, filename)
......@@ -29,7 +25,15 @@ def write_files(directory, bucket):
)
print(f"{filename} uploaded successfully.")
def main(bucket_name, local_dir):
    """Create the S3 bucket, then upload the PDF files from ``local_dir`` to it.

    Args:
        bucket_name: Name of the S3 bucket to create and upload into.
        local_dir: Local folder that is scanned for PDF files.
    """
    s3_client = boto3.client('s3')
    create_bucket(s3_client, bucket_name)
    # Upload PDF files to S3 bucket. This happens here rather than at module
    # level so that importing the file has no side effects and the call uses
    # the three-argument write_files(s3_client, directory, bucket) signature.
    print("Writing Items to Bucket")
    write_files(s3_client, local_dir, bucket_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Upload PDF files to an S3 bucket")
    parser.add_argument("bucket_name", help="The name of the S3 bucket to which the files will be uploaded")
    parser.add_argument("LOCAL_DIR", help="The name of the folder to put the pdf files")
    args = parser.parse_args()
    main(args.bucket_name, args.LOCAL_DIR)
......@@ -2,30 +2,30 @@
import boto3
import botocore
import time
import argparse
client = boto3.client('opensearchserverless')
#service = 'aoss'
Vector_store_name='test-nabil'
def createEncryptionPolicy(client):
"""Creates an encryption policy that matches all collections beginning with test"""
def createEncryptionPolicy(client,policy_name, collection_name):
"""Creates an encryption policy for the specified collection."""
try:
response = client.create_security_policy(
description='Encryption policy for test collections',
name='test-policy',
policy="""
{
description=f'Encryption policy for {collection_name}',
name=policy_name,
policy=f"""
{{
\"Rules\": [
{
{{
\"ResourceType\": \"collection\",
\"Resource\": [
\"collection\/test*\"
\"collection/{collection_name}\"
]
}
}}
],
\"AWSOwnedKey\": true
}
}}
""",
type='encryption'
)
......@@ -39,27 +39,27 @@ def createEncryptionPolicy(client):
raise error
def createNetworkPolicy(client):
"""Creates a network policy that matches all collections beginning with test"""
def createNetworkPolicy(client,policy_name,collection_name):
"""Creates a network policy for the specified collection."""
try:
response = client.create_security_policy(
description='Network policy for Test collections',
name='test-policy',
policy="""
[{
\"Description\":\"Public access for Test collection\",
description=f'Network policy for {collection_name}',
name=policy_name,
policy=f"""
[{{
\"Description\": \"Public access for {collection_name}\",
\"Rules\": [
{
{{
\"ResourceType\": \"dashboard\",
\"Resource\":[\"collection\/test*\"]
},
{
\"Resource\": [\"collection/{collection_name}\"]
}},
{{
\"ResourceType\": \"collection\",
\"Resource\":[\"collection\/test*\"]
}
\"Resource\": [\"collection/{collection_name}\"]
}}
],
\"AllowFromPublic\": true
}]
}}]
""",
type='network'
)
......@@ -73,65 +73,62 @@ def createNetworkPolicy(client):
raise error
def createAccessPolicy(client, policy_name, collection_name, IAM_USER, account_id="352909266144"):
    """Create a data access policy for the specified collection.

    The policy grants the given IAM user collection-level permissions on
    ``collection/<collection_name>`` and index-level permissions on
    ``index/<collection_name>/*``.

    Args:
        client: OpenSearch Serverless client exposing ``create_access_policy``.
        policy_name: Name under which the access policy is created.
        collection_name: Collection the policy applies to.
        IAM_USER: IAM user name placed in the policy's principal ARN.
        account_id: AWS account id used in the principal ARN. Defaults to the
            account that was previously hard-coded in the policy text.

    Raises:
        botocore.exceptions.ClientError: For any service error other than a
            ``ConflictException`` (policy name already exists), which is
            reported on stdout and otherwise ignored.
    """
    # Imported locally because this module does not import json at file level.
    import json

    # Build the document as Python data and serialise it, so the result is
    # always valid JSON and interpolated values are escaped correctly.
    policy_document = [
        {
            "Rules": [
                {
                    "Resource": [f"collection/{collection_name}"],
                    "Permission": [
                        "aoss:CreateCollectionItems",
                        "aoss:DeleteCollectionItems",
                        "aoss:UpdateCollectionItems",
                        "aoss:DescribeCollectionItems",
                    ],
                    "ResourceType": "collection",
                },
                {
                    "Resource": [f"index/{collection_name}/*"],
                    "Permission": [
                        "aoss:CreateIndex",
                        "aoss:DeleteIndex",
                        "aoss:UpdateIndex",
                        "aoss:DescribeIndex",
                        "aoss:ReadDocument",
                        "aoss:WriteDocument",
                    ],
                    "ResourceType": "index",
                },
            ],
            "Principal": [f"arn:aws:iam::{account_id}:user/{IAM_USER}"],
        }
    ]

    try:
        response = client.create_access_policy(
            description=f'Data access policy for {collection_name}',
            name=policy_name,
            policy=json.dumps(policy_document),
            type='data',
        )
    except botocore.exceptions.ClientError as error:
        # An existing policy with the same name is not fatal: report and continue.
        if error.response['Error']['Code'] == 'ConflictException':
            print('[ConflictException] An access policy with this name already exists.')
        else:
            raise
    else:
        print('\nAccess policy created:')
        print(response)
def waitForCollectionCreation(client):
def waitForCollectionCreation(client,collection_name):
"""Waits for the collection to become active"""
time.sleep(40)
time.sleep(30)
response = client.batch_get_collection(
names=['test1'])
names=[collection_name])
print('\nCollection successfully created:')
print(response["collectionDetails"])
# Extract the collection endpoint from the response
......@@ -140,16 +137,22 @@ def waitForCollectionCreation(client):
return final_host
def main(collection_name, IAM_USER):
    """Create the policies and the vector-search collection, then print its endpoint.

    Creates an encryption, a network and a data access policy named after the
    collection, creates the collection itself, waits for it via
    ``waitForCollectionCreation`` and prints the result.

    Args:
        collection_name: Name of the collection; also the prefix of the three
            policy names.
        IAM_USER: IAM user name granted access by the data access policy.
    """
    # NOTE(review): the service presumably limits policy-name length, so a long
    # collection name plus these suffixes may be rejected -- confirm the limit.
    encryption_policy_name = f'{collection_name}-encryption-policy'
    network_policy_name = f'{collection_name}-network-policy'
    access_policy_name = f'{collection_name}-access-policy'

    # Each helper is called once with its current (client, policy_name,
    # collection_name[, IAM_USER]) signature.
    createEncryptionPolicy(client, encryption_policy_name, collection_name)
    createNetworkPolicy(client, network_policy_name, collection_name)
    createAccessPolicy(client, access_policy_name, collection_name, IAM_USER)

    # The collection is created once, under the caller-supplied name.
    collection = client.create_collection(name=collection_name, type='VECTORSEARCH')
    ENDPOINT = waitForCollectionCreation(client, collection_name)
    print("Collection created successfully:", collection)
    print("Collection ENDPOINT:", ENDPOINT)
if __name__ == "__main__":
    # Both values come from the command line; main() requires them, so it is
    # called only once, after parsing.
    parser = argparse.ArgumentParser(description="Create collection")
    parser.add_argument("collection_name", help="The name of the collection")
    parser.add_argument("iam_user", help="The iam user")
    args = parser.parse_args()
    main(args.collection_name, args.iam_user)
......@@ -8,41 +8,30 @@ from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
from langchain_community.vectorstores import OpenSearchVectorSearch
import uuid
import json
import argparse
## Local directory for storing PDF files
LOCAL_DIR = "pdfs"
# Module-level index name. main(bucket_name, endpoint, index_name) takes its
# own index_name parameter, which shadows this value inside main().
index_name = "cloud_lecture"
## S3_client
s3_client = boto3.client('s3')
## Bucket name where documents are stored
BUCKET_NAME = "cloud-lecture-2023"
## Bedrock client
bedrock_client = boto3.client(service_name="bedrock-runtime")
## Configuration for AWS authentication and OpenSearch client
# NOTE(review): credentials is assigned twice; the second assignment (named
# profile 'master-group-14') overrides the first, so it is the one awsauth uses.
credentials = boto3.Session().get_credentials()
credentials = boto3.Session(profile_name='master-group-14').get_credentials()
# Request signer built for region 'us-east-1' and service name 'aoss'.
awsauth = AWSV4SignerAuth(credentials, 'us-east-1', 'aoss')
## Vector DB endpoint
host= 'j6phg34iv0f2rlvxwawd.us-east-1.aoss.amazonaws.com'
## Opensearch Client
# NOTE(review): this module-level client targets the fixed host above, while
# main() builds its own client from its endpoint argument -- confirm whether
# this one is still needed.
OpenSearch_client = OpenSearch(
hosts=[{'host': host, 'port': 443}],
http_auth=awsauth,
use_ssl=True,
verify_certs=True,
connection_class=RequestsHttpConnection,
)
## Create Index in Opensearch
def create_index(index_name):
def create_index(client,index_name):
indexBody = {
"settings": {
"index.knn": True
......@@ -62,7 +51,7 @@ def create_index(index_name):
}
try:
create_response = OpenSearch_client.indices.create(index_name, body=indexBody)
create_response = client.indices.create(index_name, body=indexBody)
print('\nCreating index:')
print(create_response)
except Exception as e:
......@@ -101,6 +90,7 @@ def generate_embeddings(bedrock_client, chunks):
# Store generated embeddings into an OpenSearch index.
def store_embeddings(embeddings, texts, meta_data, host, awsauth, index_name):
docsearch = OpenSearchVectorSearch.from_embeddings(
embeddings,
texts,
......@@ -137,14 +127,25 @@ def generate_store_embeddings(bedrock_client, chunks,awsauth,index_name):
## main
def main():
def main(bucket_name, endpoint,index_name):
## Opensearch Client
OpenSearch_client = OpenSearch(
hosts=[{'host': endpoint, 'port': 443}],
http_auth=awsauth,
use_ssl=True,
verify_certs=True,
connection_class=RequestsHttpConnection,
)
download_documents(BUCKET_NAME,LOCAL_DIR)
download_documents(bucket_name,LOCAL_DIR)
loader= PyPDFDirectoryLoader(LOCAL_DIR)
docs = loader.load()
print(docs[1])
chunks = split_text(docs, 1000, 100)
print(chunks[1])
create_index(OpenSearch_client,index_name)
embeddings= generate_embeddings(bedrock_client, chunks)
print(embeddings[1])
texts = [chunk.page_content for chunk in chunks]
......@@ -152,7 +153,7 @@ def main():
meta_data = [{'source': chunk.metadata['source'], 'page': chunk.metadata['page'] + 1} for chunk in chunks]
print(embeddings[1])
print(meta_data[1])
store_embeddings(embeddings, texts, meta_data ,host, awsauth,index_name)
store_embeddings(embeddings, texts, meta_data ,endpoint, awsauth,index_name)
......@@ -163,4 +164,9 @@ def main():
if __name__ == "__main__":
    # main() requires bucket_name, endpoint and index_name, so it is called
    # only once, with the parsed command-line values.
    parser = argparse.ArgumentParser(description="Process PDF documents and store their embeddings.")
    parser.add_argument("bucket_name", help="The S3 bucket name where documents are stored")
    parser.add_argument("endpoint", help="The OpenSearch service endpoint")
    parser.add_argument("index_name", help="The name of the OpenSearch index")
    args = parser.parse_args()
    main(args.bucket_name, args.endpoint, args.index_name)
......@@ -107,6 +107,7 @@ def main():
st.session_state.chat_history.append({"role": "user", "content": user_prompt})
# Generate and display answer
print(user_prompt)
embed_question= get_embedding(user_prompt,bedrock_client)
print(embed_question)
sim_results = similarity_search(embed_question, index_name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment