-
Notifications
You must be signed in to change notification settings - Fork 8
OpenAI: Threads #40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
OpenAI: Threads #40
Changes from all commits
7e4617f
743e5d0
00ebe6f
c24b58f
15aa41c
71bc5a8
8f11fa5
ccf7ed3
2e32a38
5081383
d45984a
941a81c
ad85de8
ee09d3e
5546222
13d913a
146341e
7c79dea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,140 @@ | ||
| import re | ||
| import requests | ||
|
|
||
| import openai | ||
| from openai import OpenAI | ||
| from fastapi import APIRouter, BackgroundTasks | ||
|
|
||
| from app.utils import APIResponse | ||
| from app.core import settings, logging | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| router = APIRouter(tags=["threads"]) | ||
|
|
||
|
|
||
def send_callback(callback_url: str, data: dict) -> bool:
    """Send results to the callback URL (synchronously).

    Args:
        callback_url: URL to POST the JSON payload to.
        data: JSON-serializable payload (an APIResponse dump).

    Returns:
        True if the callback endpoint answered with a 2xx status,
        False on any network/HTTP error (the error is logged).
    """
    try:
        session = requests.Session()
        # uncomment this to run locally without SSL
        # session.verify = False
        # A timeout is required: without one, an unresponsive callback
        # endpoint would hang this background worker indefinitely.
        response = session.post(callback_url, json=data, timeout=30)
        response.raise_for_status()
        return True
    except requests.RequestException as e:
        logger.error(f"Callback failed: {str(e)}")
        return False
|
|
||
|
|
||
def process_run(request: dict, client: OpenAI):
    """
    Background task to run create_and_poll, then send the callback with the result.
    This function is run in the background after we have already returned an initial response.

    Args:
        request: Request body dict; must contain "thread_id", "assistant_id"
            and "callback_url". Optional keys: "remove_citation" (bool),
            "endpoint", plus arbitrary passthrough fields.
        client: Configured OpenAI client.
    """
    try:
        # Start the run and block until it reaches a terminal status.
        run = client.beta.threads.runs.create_and_poll(
            thread_id=request["thread_id"],
            assistant_id=request["assistant_id"],
        )

        if run.status == "completed":
            messages = client.beta.threads.messages.list(
                thread_id=request["thread_id"])
            # messages.list returns newest-first, so data[0] is the reply.
            latest_message = messages.data[0]
            message_content = latest_message.content[0].text.value

            remove_citation = request.get("remove_citation", False)

            if remove_citation:
                # Strip OpenAI file-citation markers of the form 【12:3†source】.
                message = re.sub(r"【\d+(?::\d+)?†[^】]*】", "", message_content)
            else:
                message = message_content

            # Update the data dictionary with additional fields from the request, excluding specific keys
            additional_data = {k: v for k, v in request.items(
            ) if k not in {"question", "assistant_id", "callback_url", "thread_id"}}
            callback_response = APIResponse.success_response(data={
                "status": "success",
                "message": message,
                "thread_id": request["thread_id"],
                # BUG FIX: request is a dict, so getattr() always fell through
                # to the default; dict.get() honors a caller-supplied endpoint.
                "endpoint": request.get("endpoint", "some-default-endpoint"),
                **additional_data
            })
        else:
            callback_response = APIResponse.failure_response(
                error=f"Run failed with status: {run.status}")

        # Send callback with results
        send_callback(request["callback_url"], callback_response.model_dump())

    except openai.OpenAIError as e:
        # Handle any other OpenAI API errors; prefer the structured API
        # message when the error body provides one.
        if isinstance(e.body, dict) and "message" in e.body:
            error_message = e.body["message"]
        else:
            error_message = str(e)

        callback_response = APIResponse.failure_response(error=error_message)

        send_callback(request["callback_url"], callback_response.model_dump())
|
|
||
|
|
||
@router.post("/threads")
async def threads(request: dict, background_tasks: BackgroundTasks):
    """
    Accepts a question, assistant_id, callback_url, and optional thread_id from the request body.
    Returns an immediate "processing" response, then continues to run create_and_poll in background.
    Once completed, calls send_callback with the final result.
    """
    client = OpenAI(api_key=settings.OPENAI_API_KEY)

    # Use get method to safely access thread_id
    thread_id = request.get("thread_id")

    # 1. Validate or check if there's an existing thread with an in-progress run
    if thread_id:
        try:
            runs = client.beta.threads.runs.list(thread_id=thread_id)
            # Get the most recent run (first in the list) if any
            if runs.data and len(runs.data) > 0:
                latest_run = runs.data[0]
                if latest_run.status in ["queued", "in_progress", "requires_action"]:
                    return APIResponse.failure_response(error=f"There is an active run on this thread (status: {latest_run.status}). Please wait for it to complete.")
        except openai.OpenAIError:
            # Handle invalid thread ID
            return APIResponse.failure_response(error=f"Invalid thread ID provided {thread_id}")

        try:
            # Use existing thread
            client.beta.threads.messages.create(
                thread_id=thread_id, role="user", content=request["question"]
            )
        except openai.OpenAIError as e:
            # Consistent with the new-thread branch below: surface API errors
            # as a failure response rather than an unhandled 500.
            if isinstance(e.body, dict) and "message" in e.body:
                error_message = e.body["message"]
            else:
                error_message = str(e)
            return APIResponse.failure_response(error=error_message)
    else:
        try:
            # Create new thread
            thread = client.beta.threads.create()
            client.beta.threads.messages.create(
                thread_id=thread.id, role="user", content=request["question"]
            )
            request["thread_id"] = thread.id
        except openai.OpenAIError as e:
            # Handle any other OpenAI API errors
            if isinstance(e.body, dict) and "message" in e.body:
                error_message = e.body["message"]
            else:
                error_message = str(e)
            return APIResponse.failure_response(error=error_message)

    # 2. Send immediate response to complete the API call
    initial_response = APIResponse.success_response(data={
        "status": "processing",
        "message": "Run started",
        "thread_id": request.get("thread_id"),
        "success": True,
    })

    # 3. Schedule the background task to run create_and_poll and send callback
    background_tasks.add_task(process_run, request, client)

    # 4. Return immediately so the client knows we've accepted the request
    return initial_response
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from .config import settings | ||
| from .logger import logging | ||
|
|
||
| __all__ = ['settings', 'logging'] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
import logging
import os
from logging.handlers import RotatingFileHandler
from app.core.config import settings

LOG_DIR = settings.LOG_DIR
# exist_ok=True avoids the check-then-create race (TOCTOU) when several
# workers import this module concurrently; it raises only on a real failure
# such as missing permissions, which should abort startup anyway.
os.makedirs(LOG_DIR, exist_ok=True)

LOG_FILE_PATH = os.path.join(LOG_DIR, "app.log")

LOGGING_LEVEL = logging.INFO
LOGGING_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# Console logging via the root logger.
logging.basicConfig(level=LOGGING_LEVEL, format=LOGGING_FORMAT)

# File logging: rotate at 10 MB, keeping 5 backup files.
file_handler = RotatingFileHandler(
    LOG_FILE_PATH, maxBytes=10485760, backupCount=5)
file_handler.setLevel(LOGGING_LEVEL)
file_handler.setFormatter(logging.Formatter(LOGGING_FORMAT))

logging.getLogger("").addHandler(file_handler)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,111 @@ | ||
| import pytest | ||
| import openai | ||
|
|
||
| from unittest.mock import MagicMock, patch | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Standard library imports should come first. See PEP8 |
||
| from fastapi import FastAPI | ||
| from fastapi.testclient import TestClient | ||
|
|
||
| from app.api.routes.threads import router, process_run | ||
| from app.utils import APIResponse | ||
|
|
||
# Mount the threads router on a minimal FastAPI app so TestClient can
# exercise the endpoints in-process.
app = FastAPI()
app.include_router(router)
client = TestClient(app)
|
|
||
|
|
||
# BUG FIX: the patch target must be the module where OpenAI is *looked up*
# (app.api.routes.threads, imported above), not "src.app.api.v1.threads" —
# the old path doesn't exist, so the real client would have been used.
@patch("app.api.routes.threads.OpenAI")
def test_threads_endpoint(mock_openai):
    """
    Test the /threads endpoint when creating a new thread.
    The patched OpenAI client simulates:
      - A successful assistant ID validation.
      - New thread creation with a dummy thread id.
      - No existing runs.
    The expected response should have status "processing" and include a thread_id.
    """
    # Create a dummy client to simulate OpenAI API behavior.
    dummy_client = MagicMock()
    # Simulate a valid assistant ID by ensuring retrieve doesn't raise an error.
    dummy_client.beta.assistants.retrieve.return_value = None
    # Simulate thread creation.
    dummy_thread = MagicMock()
    dummy_thread.id = "dummy_thread_id"
    dummy_client.beta.threads.create.return_value = dummy_thread
    # Simulate message creation.
    dummy_client.beta.threads.messages.create.return_value = None
    # Simulate that no active run exists.
    dummy_client.beta.threads.runs.list.return_value = MagicMock(data=[])

    mock_openai.return_value = dummy_client

    request_data = {
        "question": "What is Glific?",
        "assistant_id": "assistant_123",
        "callback_url": "http://example.com/callback",
    }
    response = client.post("/threads", json=request_data)
    assert response.status_code == 200
    response_json = response.json()
    assert response_json["success"] is True
    assert response_json["data"]["status"] == "processing"
    assert response_json["data"]["message"] == "Run started"
    assert response_json["data"]["thread_id"] == "dummy_thread_id"
|
|
||
|
|
||
# BUG FIX: patch targets corrected from "src.app.api.v1.threads.*" to the
# module actually imported in this file (app.api.routes.threads); the old
# path doesn't exist, so neither OpenAI nor send_callback would be patched.
@patch("app.api.routes.threads.OpenAI")
@pytest.mark.parametrize(
    "remove_citation, expected_message",
    [
        (
            True,
            "Glific is an open-source, two-way messaging platform designed for nonprofits to scale their outreach via WhatsApp",
        ),
        (
            False,
            "Glific is an open-source, two-way messaging platform designed for nonprofits to scale their outreach via WhatsApp【1:2†citation】",
        ),
    ],
)
def test_process_run_variants(mock_openai, remove_citation, expected_message):
    """
    Test process_run for both remove_citation variants:
      - Mocks the OpenAI client to simulate a completed run.
      - Verifies that send_callback is called with the expected message based on the remove_citation flag.
    """
    # Setup the mock client.
    mock_client = MagicMock()
    mock_openai.return_value = mock_client

    # Create the request with the variable remove_citation flag.
    request = {
        "question": "What is Glific?",
        "assistant_id": "assistant_123",
        "callback_url": "http://example.com/callback",
        "thread_id": "thread_123",
        "remove_citation": remove_citation,
    }

    # Simulate a completed run.
    mock_run = MagicMock()
    mock_run.status = "completed"
    mock_client.beta.threads.runs.create_and_poll.return_value = mock_run

    # Set up the dummy message based on the remove_citation flag.
    base_message = "Glific is an open-source, two-way messaging platform designed for nonprofits to scale their outreach via WhatsApp"
    citation_message = base_message if remove_citation else f"{base_message}【1:2†citation】"
    dummy_message = MagicMock()
    dummy_message.content = [MagicMock(text=MagicMock(value=citation_message))]
    mock_client.beta.threads.messages.list.return_value.data = [dummy_message]

    # Patch send_callback where process_run resolves it, and invoke process_run.
    with patch("app.api.routes.threads.send_callback") as mock_send_callback:
        process_run(request, mock_client)
        mock_send_callback.assert_called_once()
        callback_url, payload = mock_send_callback.call_args[0]
        assert callback_url == request["callback_url"]
        assert payload["data"]["message"] == expected_message
        assert payload["data"]["status"] == "success"
        assert payload["data"]["thread_id"] == "thread_123"
        assert payload["success"] is True
Uh oh!
There was an error while loading. Please reload this page.