diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 6908dab..c806cff 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -35,7 +35,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-latest, macos-13]
+        os: [ubuntu-latest, macos-15-intel]
 
     steps:
       - uses: actions/checkout@v4
diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py
index 9a6e26c..f62b6c6 100644
--- a/src/agent/profiles/base.py
+++ b/src/agent/profiles/base.py
@@ -1,4 +1,4 @@
-from typing import Annotated, TypedDict
+from typing import Annotated, Literal, TypedDict
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -7,6 +7,7 @@ from langgraph.graph.message import add_messages
 
 from agent.tasks.rephrase import create_rephrase_chain
+from agent.tasks.safety_checker import create_safety_checker
 from tools.external_search.state import SearchState, WebSearchResult
 from tools.external_search.workflow import create_search_workflow
@@ -27,6 +28,8 @@ class OutputState(TypedDict, total=False):
 class BaseState(InputState, OutputState, total=False):
     rephrased_input: str  # LLM-generated query from user input
     chat_history: Annotated[list[BaseMessage], add_messages]
+    safety: str  # LLM-assessed safety level of the user input
+    reason_unsafe: str
 
 
 class BaseGraphBuilder:
@@ -38,21 +41,36 @@ def __init__(
         self,
         llm: BaseChatModel,
         embedding: Embeddings,
     ) -> None:
         self.rephrase_chain: Runnable = create_rephrase_chain(llm)
+        self.safety_checker: Runnable = create_safety_checker(llm)
         self.search_workflow: Runnable = create_search_workflow(llm)
 
     async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
         rephrased_input: str = await self.rephrase_chain.ainvoke(
             {
                 "user_input": state["user_input"],
-                "chat_history": state["chat_history"],
+                "chat_history": state.get("chat_history", []),
             },
             config,
         )
-        return BaseState(rephrased_input=rephrased_input)
+        safety_check: BaseState = await self.safety_checker.ainvoke(
+            {"rephrased_input": rephrased_input},
+            config,
+        )
+        return BaseState(
+            rephrased_input=rephrased_input,
+            safety=safety_check["safety"].lower(),
+            reason_unsafe=safety_check["reason_unsafe"],
+        )
+
+    def proceed_with_research(self, state: BaseState) -> Literal["Continue", "Finish"]:
+        return "Continue" if state["safety"] == "true" else "Finish"
 
     async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
         search_results: list[WebSearchResult] = []
-        if config["configurable"]["enable_postprocess"]:
+        if (
+            config["configurable"].get("enable_postprocess", False)
+            and state["safety"] == "true"
+        ):
             result: SearchState = await self.search_workflow.ainvoke(
                 SearchState(
                     input=state["rephrased_input"],
diff --git a/src/agent/profiles/cross_database.py b/src/agent/profiles/cross_database.py
index 74ef26c..f96b531 100644
--- a/src/agent/profiles/cross_database.py
+++ b/src/agent/profiles/cross_database.py
@@ -22,7 +22,6 @@ class CrossDatabaseState(BaseState):
-    safety: str  # LLM-assessed safety level of the user input
     query_language: str  # language of the user input
     reactome_query: str  # LLM-generated query for Reactome
@@ -122,14 +121,6 @@ async def check_question_safety(
         else:
             return CrossDatabaseState(safety=result.binary_score)
 
-    async def proceed_with_research(
-        self, state: CrossDatabaseState
-    ) -> Literal["Continue", "Finish"]:
-        if state["safety"] == "Yes":
-            return "Continue"
-        else:
-            return "Finish"
-
     async def identify_query_language(
         self, state: CrossDatabaseState, config: RunnableConfig
     ) -> CrossDatabaseState:
diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py
index c162ac7..76364e4 100644
--- a/src/agent/profiles/react_to_me.py
+++ b/src/agent/profiles/react_to_me.py
@@ -7,6 +7,7 @@ from langgraph.graph.state import StateGraph
 
 from agent.profiles.base import BaseGraphBuilder, BaseState
+from agent.tasks.unsafe_answer import create_unsafe_answer_generator
 from retrievers.reactome.rag import create_reactome_rag
 
@@ -23,6 +24,7 @@ def __init__(
         super().__init__(llm, embedding)
 
         # Create runnables (tasks & tools)
+        self.unsafe_answer_generator = create_unsafe_answer_generator(llm)
         self.reactome_rag: Runnable = create_reactome_rag(
             llm, embedding, streaming=True
         )
@@ -32,25 +34,47 @@ def __init__(
         # Set up nodes
         state_graph.add_node("preprocess", self.preprocess)
         state_graph.add_node("model", self.call_model)
+        state_graph.add_node("generate_unsafe_response", self.generate_unsafe_response)
         state_graph.add_node("postprocess", self.postprocess)
 
         # Set up edges
        state_graph.set_entry_point("preprocess")
-        state_graph.add_edge("preprocess", "model")
+        state_graph.add_conditional_edges(
+            "preprocess",
+            self.proceed_with_research,
+            {"Continue": "model", "Finish": "generate_unsafe_response"},
+        )
         state_graph.add_edge("model", "postprocess")
+        state_graph.add_edge("generate_unsafe_response", "postprocess")
         state_graph.set_finish_point("postprocess")
 
         self.uncompiled_graph: StateGraph = state_graph
 
+    async def generate_unsafe_response(
+        self, state: ReactToMeState, config: RunnableConfig
+    ) -> ReactToMeState:
+        answer: str = await self.unsafe_answer_generator.ainvoke(
+            {
+                "user_input": state["rephrased_input"],
+                "reason_unsafe": state["reason_unsafe"],
+            },
+            config,
+        )
+        return ReactToMeState(
+            chat_history=[
+                HumanMessage(state["user_input"]),
+                AIMessage(answer),
+            ],
+            answer=answer,
+        )
+
     async def call_model(
         self, state: ReactToMeState, config: RunnableConfig
     ) -> ReactToMeState:
         result: dict[str, Any] = await self.reactome_rag.ainvoke(
             {
                 "input": state["rephrased_input"],
-                "chat_history": (
-                    state["chat_history"]
-                    if state["chat_history"]
-                    else [HumanMessage(state["user_input"])]
+                "chat_history": state.get(
+                    "chat_history", [HumanMessage(state["user_input"])]
                 ),
             },
             config,
diff --git a/src/agent/tasks/safety_checker.py b/src/agent/tasks/safety_checker.py
index ac372e3..b997731 100644
--- a/src/agent/tasks/safety_checker.py
+++ b/src/agent/tasks/safety_checker.py
@@ -4,25 +4,45 @@ from pydantic import BaseModel, Field
 
 safety_check_message = """
-You are an expert grader with a strong understanding of ethical guidelines and responsible AI usage.
+You are an expert scientific assistant. You have advanced training in scientific ethics, dual-use assessment, and responsible AI.
 
-Your task is to determine whether a user-submitted question is appropriate and not harmful.
-Provide a binary output as either:
-Yes: The question is appropriate and not harmful.
-No: The question is harmful, offensive, inappropriate, or unethical.
+Tasks:
+
+1. Safety & Appropriateness Check
+  - Assess the question and determine if it is safe and appropriate for responsible AI use.
+  - The question and any plausible answer to it must NOT be harmful, offensive, unethical, or violate scientific or ethical guidelines.
+  - The question or its likely answers must NOT request, describe, or explore how to carry out harmful, illegal, dangerous, or dual-use activities (such as modifying pathogens, increasing infectivity, or other forms of misuse), even if the scenario is hypothetical, fictional, or indirect.
+  - The question must NOT request personal, medical, legal, or other types of advice outside the scope of scientific knowledge.
+  - If a question is ambiguous or could plausibly be used for harm, treat it as unsafe.
+  - Treat hypothetical, fictional, or made-up scenarios with the same level of scrutiny as real-world questions.
+
+2. Reactome Relevance Check
+  - Determine if the question is relevant to biology, life sciences, molecular biology, or related topics.
+  - Mark questions as not relevant if they are about unrelated topics (such as programming, math, history, trivia, etc.).
+
+IMPORTANT:
+  - If the standalone question is unsafe or not relevant, return "safety": "false".
+  - If the standalone question is both safe **and** relevant, return "safety": "true".
+
+Return only a JSON object in the following format:
+  "safety": "true" or "false",  // Use string, not boolean. Mark as "false" if unsafe OR not relevant.
+  "reason_unsafe": "..."  // If 'safety' is false, briefly state the reason
 """
 
 
 safety_check_prompt = ChatPromptTemplate.from_messages(
     [
         ("system", safety_check_message),
-        ("human", "User question: \n\n {input}"),
+        ("human", "User question: \n\n {rephrased_input}"),
     ]
 )
 
 
 class SafetyCheck(BaseModel):
-    binary_score: str = Field(
-        description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'Yes' or 'No'."
+    safety: str = Field(
+        description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'true' or 'false'."
+    )
+    reason_unsafe: str = Field(
+        description="If 'safety' is false, briefly state the reason; if 'safety' is true, leave this field empty."
     )
diff --git a/src/agent/tasks/unsafe_answer.py b/src/agent/tasks/unsafe_answer.py
new file mode 100644
index 0000000..42d4ae6
--- /dev/null
+++ b/src/agent/tasks/unsafe_answer.py
@@ -0,0 +1,40 @@
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import Runnable
+
+unsafe_answer_message = """
+You are an expert scientific assistant operating under the React-to-Me platform. React-to-Me helps both experts and non-experts explore molecular biology using trusted data from the Reactome database.
+
+You have advanced training in scientific ethics, dual-use research concerns, and responsible AI use.
+
+You will receive two inputs:
+1. The user's question.
+2. A system-generated variable called `reason_unsafe`, which explains why the question cannot be answered.
+
+Your task is to clearly, respectfully, and firmly explain to the user *why* their question cannot be answered, based solely on the `reason_unsafe` input. Do **not** attempt to answer, rephrase, or guide the user toward answering the original question.
+
+You must:
+- Politely explain the refusal, grounded in the `reason_unsafe`.
+- Emphasize React-to-Me's mission: to support responsible exploration of molecular biology through trusted databases.
+- Suggest examples of appropriate topics (e.g., protein function, pathways, gene interactions using Reactome/UniProt).
+
+You must not provide any workaround, implicit answer, or redirection toward unsafe content.
+"""
+
+unsafe_answer_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", unsafe_answer_message),
+        (
+            "user",
+            "Question: {user_input}\n\nReason the question is unsafe or out of scope: {reason_unsafe}",
+        ),
+    ]
+)
+
+
+def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable:
+    streaming_llm = llm.model_copy(update={"streaming": True})
+    return (unsafe_answer_prompt | streaming_llm | StrOutputParser()).with_config(
+        run_name="unsafe_answer"
+    )
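Note for reviewers: `BaseGraphBuilder.__init__` now calls `create_safety_checker(llm)` and `preprocess` indexes the result as a mapping (`safety_check["safety"]`, `safety_check["reason_unsafe"]`), but the factory function itself lies outside the changed hunks of `src/agent/tasks/safety_checker.py`. The sketch below is a minimal, hypothetical version of such a factory, assuming it binds the `SafetyCheck` schema with LangChain's `with_structured_output` and exposes the result as a dict; the actual implementation in the repository may differ.

```python
# Sketch only: a create_safety_checker compatible with how preprocess() consumes it.
# safety_check_prompt and SafetyCheck are the module-level objects defined in
# src/agent/tasks/safety_checker.py above; the real factory is not shown in this diff.
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.runnables import Runnable, RunnableLambda


def create_safety_checker(llm: BaseChatModel) -> Runnable:
    # Constrain the LLM to emit the SafetyCheck schema (safety + reason_unsafe).
    structured_llm = llm.with_structured_output(SafetyCheck)
    return (
        safety_check_prompt
        | structured_llm
        # preprocess() reads the result with ["safety"] / ["reason_unsafe"],
        # so convert the pydantic model into a plain dict.
        | RunnableLambda(lambda check: check.model_dump())
    ).with_config(run_name="safety_checker")
```

This mirrors the `prompt | llm | parser` pattern used by `create_unsafe_answer_generator`, and returning a dict keeps `safety_check["safety"].lower()` in `BaseGraphBuilder.preprocess` working without further changes.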