Skip to content

Commit

Permalink
chore: run eslint
Browse files Browse the repository at this point in the history
  • Loading branch information
gbone-restore committed Oct 11, 2024
1 parent e32f9ab commit f8e5bd2
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,25 @@ def __init__(
self.user_id = user_id
self.max_messages = max_messages # Store max_messages


def _get_full_history(self) -> List[BaseMessage]:
"""Query all messages from DynamoDB for the current session"""
messages: List[BaseMessage] = []
response = self.table.query(
response = self.table.query(
KeyConditionExpression="#pk = :user_id AND begins_with(#sk, :session_prefix)",
FilterExpression="#itemType = :itemType",
ExpressionAttributeNames={
"#pk": "PK",
"#sk": "SK",
"#itemType": "ItemType"
"#itemType": "ItemType",
},
ScanIndexForward=True,
ExpressionAttributeValues={
":user_id": f"USER#{self.user_id}",
":session_prefix": f"SESSION#{self.session_id}",
":itemType": "message"
}
":itemType": "message",
},
)
items = response.get('Items', [])
items = response.get("Items", [])

return items

Expand All @@ -64,13 +63,15 @@ def messages(self) -> List[BaseMessage]:
# Hande case where max_messages is None
if self.max_messages is None:
self.max.messages = len(full_history_items)

# Slice before processing
relevant_items = full_history_items[-self.max_messages:]
relevant_items = full_history_items[-self.max_messages :]

# Use itemgetter and list comprehension
get_history_data = itemgetter('History')
return [_message_from_dict(get_history_data(item) or '') for item in relevant_items]
get_history_data = itemgetter("History")
return [
_message_from_dict(get_history_data(item) or "") for item in relevant_items
]

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in DynamoDB"""
Expand All @@ -93,22 +94,20 @@ def add_message(self, message: BaseMessage) -> None:
self.table.update_item(
Key={
"PK": f"USER#{self.user_id}",
"SK": f"SESSION#{self.session_id}"
"SK": f"SESSION#{self.session_id}",
},
UpdateExpression="SET LastUpdateTime = :time",
ConditionExpression="attribute_exists(PK)",
ExpressionAttributeValues={
":time": current_time
}
ExpressionAttributeValues={":time": current_time},
)
except ClientError as err:
if err.response['Error']['Code'] == 'ConditionalCheckFailedException':
if err.response["Error"]["Code"] == "ConditionalCheckFailedException":
# Session doesn't exist, so create a new one
self.table.put_item(
Item={
"PK": f"USER#{self.user_id}",
"SK": f"SESSION#{self.session_id}",
"Title": _message_to_dict(message)
"Title": _message_to_dict(message)
.get("data", {})
.get("content", "<no title>"),
"StartTime": current_time,
Expand All @@ -121,8 +120,6 @@ def add_message(self, message: BaseMessage) -> None:
# If some other error occurs, re-raise the exception
raise



self.table.put_item(
Item={
"PK": f"USER#{self.user_id}",
Expand Down Expand Up @@ -154,15 +151,11 @@ def add_metadata(self, metadata: dict) -> None:
self.table.update_item(
Key={
"PK": f"USER#{self.user_id}",
"SK": f"SESSION#{self.session_id}#{most_recent_history['StartTime']}"
"SK": f"SESSION#{self.session_id}#{most_recent_history['StartTime']}",
},
UpdateExpression="SET #data = :data",
ExpressionAttributeNames={
"#data": "History"
},
ExpressionAttributeValues={
":data": most_recent_history["History"]
}
ExpressionAttributeNames={"#data": "History"},
ExpressionAttributeValues={":data": most_recent_history["History"]},
)

except Exception as err:
Expand Down
51 changes: 26 additions & 25 deletions lib/shared/layers/python-sdk/python/genai_core/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
table = dynamodb.Table(SESSIONS_TABLE_NAME)
logger = Logger()


def _get_messages_by_session_id(session_id, user_id):
items = []
try:
Expand All @@ -23,33 +24,30 @@ def _get_messages_by_session_id(session_id, user_id):
ExpressionAttributeNames={
"#pk": "PK",
"#sk": "SK",
"#item_type": "ItemType"
"#item_type": "ItemType",
},
ExpressionAttributeValues={
':user_id': f'USER#{user_id}',
':session_prefix': f'SESSION#{session_id}',
':session_type': 'message'
":user_id": f"USER#{user_id}",
":session_prefix": f"SESSION#{session_id}",
":session_type": "message",
},
ScanIndexForward=True
ScanIndexForward=True,
)

items = response.get('Items', [])
items = response.get("Items", [])

# If there are more items, continue querying
while 'LastEvaluatedKey' in response:
while "LastEvaluatedKey" in response:
response = table.query(
KeyConditionExpression="#pk = :user_id AND begins_with(#sk, :session_prefix)",
ExpressionAttributeNames={
"#pk": "PK",
"#sk": "SK"
},
ExpressionAttributeNames={"#pk": "PK", "#sk": "SK"},
ExpressionAttributeValues={
':user_id': f'USER#{user_id}',
':session_prefix': f'SESSION#{session_id}'
":user_id": f"USER#{user_id}",
":session_prefix": f"SESSION#{session_id}",
},
ScanIndexForward=True
ScanIndexForward=True,
)
items.extend(response.get('Items', []))
items.extend(response.get("Items", []))

except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
Expand All @@ -59,6 +57,7 @@ def _get_messages_by_session_id(session_id, user_id):

return items


def get_session(session_id, user_id):
try:
items = _get_messages_by_session_id(session_id, user_id)
Expand All @@ -72,11 +71,10 @@ def get_session(session_id, user_id):
}

for item in items:
if 'ItemType' in item:
if item['ItemType'] == 'message':
returnItem['History'].append(item['History'])
returnItem['StartTime']= item['StartTime']

if "ItemType" in item:
if item["ItemType"] == "message":
returnItem["History"].append(item["History"])
returnItem["StartTime"] = item["StartTime"]

except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
Expand All @@ -86,6 +84,7 @@ def get_session(session_id, user_id):

return returnItem


def list_sessions_by_user_id(user_id: str) -> List[Dict[str, Any]]:
"""
List all sessions for a given user ID.
Expand All @@ -106,13 +105,13 @@ def list_sessions_by_user_id(user_id: str) -> List[Dict[str, Any]]:
"ExpressionAttributeNames": {
"#pk": "PK",
"#sk": "SK",
"#item_type": "ItemType"
"#item_type": "ItemType",
},
"ExpressionAttributeValues": {
":user_id": f"USER#{user_id}",
":session_prefix": "SESSION#",
":session_type": "session"
}
":session_type": "session",
},
}

if last_evaluated_key:
Expand All @@ -135,6 +134,7 @@ def list_sessions_by_user_id(user_id: str) -> List[Dict[str, Any]]:

return session_items


def delete_session(session_id, user_id):
try:
session_history = _get_messages_by_session_id(session_id, user_id)
Expand Down Expand Up @@ -170,15 +170,16 @@ def delete_session(session_id, user_id):
return {"id": session_id, "deleted": True}



def delete_user_sessions(user_id):
try:
sessions = list_sessions_by_user_id(user_id) # Get all sessions for the user
ret_value = []

for session in sessions:
# Extract the session ID from the SK (assuming SK is in the format 'SESSION#<session_id>')
session_id = session["SK"].split("#")[1] # Extracting session ID from 'SESSION#<session_id>'
session_id = session["SK"].split("#")[
1
] # Extracting session ID from 'SESSION#<session_id>'

# Delete each session
result = delete_session(session_id, user_id)
Expand Down

0 comments on commit f8e5bd2

Please sign in to comment.