diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py
index 9a6e26c..fd7ec68 100644
--- a/src/agent/profiles/base.py
+++ b/src/agent/profiles/base.py
@@ -1,4 +1,4 @@
-from typing import Annotated, TypedDict
+from typing import Annotated, Literal, TypedDict
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -6,9 +6,14 @@
 from langchain_core.runnables import Runnable, RunnableConfig
 from langgraph.graph.message import add_messages
 
-from agent.tasks.rephrase import create_rephrase_chain
 from tools.external_search.state import SearchState, WebSearchResult
 from tools.external_search.workflow import create_search_workflow
+from tools.preprocessing.state import PreprocessingState
+from tools.preprocessing.workflow import create_preprocessing_workflow
+
+SAFETY_SAFE: Literal["true"] = "true"
+SAFETY_UNSAFE: Literal["false"] = "false"
+DEFAULT_LANGUAGE: str = "English"
 
 
 class AdditionalContent(TypedDict, total=False):
@@ -20,39 +25,49 @@ class InputState(TypedDict, total=False):
 
 
 class OutputState(TypedDict, total=False):
-    answer: str  # primary LLM response that is streamed to the user
+    answer: str  # LLM response streamed to the user
     additional_content: AdditionalContent  # sends on graph completion
 
 
 class BaseState(InputState, OutputState, total=False):
-    rephrased_input: str  # LLM-generated query from user input
+    rephrased_input: (
+        str  # contextualized, LLM-generated standalone query from user input
+    )
     chat_history: Annotated[list[BaseMessage], add_messages]
+    safety: str  # LLM-assessed safety level of the user input
+    reason_unsafe: str
 
 
 class BaseGraphBuilder:
-    # NOTE: Anything that is common to all graph builders goes here
     def __init__(
         self,
        llm: BaseChatModel,
        embedding: Embeddings,
    ) -> None:
-        self.rephrase_chain: Runnable = create_rephrase_chain(llm)
+        self.preprocessing_workflow: Runnable = create_preprocessing_workflow(llm)
         self.search_workflow: Runnable = create_search_workflow(llm)
 
     async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
-        rephrased_input: str = await self.rephrase_chain.ainvoke(
-            {
-                "user_input": state["user_input"],
-                "chat_history": state["chat_history"],
-            },
+        result: PreprocessingState = await self.preprocessing_workflow.ainvoke(
+            PreprocessingState(
+                user_input=state["user_input"],
+                chat_history=state.get("chat_history", []),
+            ),
             config,
         )
-        return BaseState(rephrased_input=rephrased_input)
+        mapped_state = BaseState(
+            rephrased_input=result.get("rephrased_input", ""),
+            safety=result.get("safety", SAFETY_SAFE),
+            reason_unsafe=result.get("reason_unsafe", ""),
+        )
+        return BaseState(**state, **mapped_state)
 
     async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
         search_results: list[WebSearchResult] = []
-        if config["configurable"]["enable_postprocess"]:
+        if state.get("safety") == SAFETY_SAFE and config.get("configurable", {}).get(
+            "enable_postprocess", False
+        ):
             result: SearchState = await self.search_workflow.ainvoke(
                 SearchState(
                     input=state["rephrased_input"],
@@ -62,5 +77,6 @@ async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseSta
             )
             search_results = result["search_results"]
         return BaseState(
-            additional_content=AdditionalContent(search_results=search_results)
+            **state,
+            additional_content=AdditionalContent(search_results=search_results),
         )
diff --git a/src/agent/profiles/cross_database.py b/src/agent/profiles/cross_database.py
index 74ef26c..8a725f1 100644
--- a/src/agent/profiles/cross_database.py
+++ b/src/agent/profiles/cross_database.py
@@ -15,16 +15,12 @@
     create_uniprot_rewriter_w_reactome
 from agent.tasks.cross_database.summarize_reactome_uniprot import \
     create_reactome_uniprot_summarizer
-from agent.tasks.detect_language import create_language_detector
 from agent.tasks.safety_checker import SafetyCheck, create_safety_checker
 from retrievers.reactome.rag import create_reactome_rag
 from retrievers.uniprot.rag import create_uniprot_rag
 
 
 class CrossDatabaseState(BaseState):
-    safety: str  # LLM-assessed safety level of the user input
-    query_language: str  # language of the user input
-
     reactome_query: str  # LLM-generated query for Reactome
     reactome_answer: str  # LLM-generated answer from Reactome
     reactome_completeness: str  # LLM-assessed completeness of the Reactome answer
@@ -48,7 +44,6 @@ def __init__(
 
         self.safety_checker = create_safety_checker(llm)
         self.completeness_checker = create_completeness_grader(llm)
-        self.detect_language = create_language_detector(llm)
         self.write_reactome_query = create_reactome_rewriter_w_uniprot(llm)
         self.write_uniprot_query = create_uniprot_rewriter_w_reactome(llm)
         self.summarize_final_answer = create_reactome_uniprot_summarizer(
@@ -60,7 +55,6 @@ def __init__(
         # Set up nodes
         state_graph.add_node("check_question_safety", self.check_question_safety)
         state_graph.add_node("preprocess_question", self.preprocess)
-        state_graph.add_node("identify_query_language", self.identify_query_language)
         state_graph.add_node("conduct_research", self.conduct_research)
         state_graph.add_node("generate_reactome_answer", self.generate_reactome_answer)
         state_graph.add_node("rewrite_reactome_query", self.rewrite_reactome_query)
@@ -74,7 +68,6 @@ def __init__(
         state_graph.add_node("postprocess", self.postprocess)
         # Set up edges
         state_graph.set_entry_point("preprocess_question")
-        state_graph.add_edge("preprocess_question", "identify_query_language")
         state_graph.add_edge("preprocess_question", "check_question_safety")
         state_graph.add_conditional_edges(
             "check_question_safety",
@@ -112,7 +105,7 @@ async def check_question_safety(
             config,
         )
         if result.binary_score == "No":
-            inappropriate_input = f"This is the user's question and it is NOT appropriate for you to answer: {state["user_input"]}. \n\n explain that you are unable to answer the question but you can answer questions about topics related to the Reactome Pathway Knowledgebase or UniProt Knowledgebas."
+            inappropriate_input = f"This is the user's question and it is NOT appropriate for you to answer: {state["user_input"]}. \n\n explain that you are unable to answer the question but you can answer questions about topics related to the Reactome Pathway Knowledgebase or UniProt Knowledgebase."
             return CrossDatabaseState(
                 safety=result.binary_score,
                 user_input=inappropriate_input,
@@ -130,14 +123,6 @@ async def proceed_with_research(
         else:
             return "Finish"
 
-    async def identify_query_language(
-        self, state: CrossDatabaseState, config: RunnableConfig
-    ) -> CrossDatabaseState:
-        query_language: str = await self.detect_language.ainvoke(
-            {"user_input": state["user_input"]}, config
-        )
-        return CrossDatabaseState(query_language=query_language)
-
     async def conduct_research(
         self, state: CrossDatabaseState, config: RunnableConfig
     ) -> CrossDatabaseState:
@@ -256,7 +241,6 @@ async def generate_final_response(
         final_response: str = await self.summarize_final_answer.ainvoke(
             {
                 "input": state["rephrased_input"],
-                "query_language": state["query_language"],
                 "reactome_answer": state["reactome_answer"],
                 "uniprot_answer": state["uniprot_answer"],
             },
diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py
index c162ac7..54b9bcb 100644
--- a/src/agent/profiles/react_to_me.py
+++ b/src/agent/profiles/react_to_me.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Literal
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -6,7 +6,9 @@
 from langchain_core.runnables import Runnable, RunnableConfig
 from langgraph.graph.state import StateGraph
 
-from agent.profiles.base import BaseGraphBuilder, BaseState
+from agent.profiles.base import (SAFETY_SAFE, SAFETY_UNSAFE, BaseGraphBuilder,
+                                 BaseState)
+from agent.tasks.unsafe_answer import create_unsafe_answer_generator
 from retrievers.reactome.rag import create_reactome_rag
 
 
@@ -23,6 +25,7 @@ def __init__(
         super().__init__(llm, embedding)
 
         # Create runnables (tasks & tools)
+        self.unsafe_answer_generator = create_unsafe_answer_generator(llm)
         self.reactome_rag: Runnable = create_reactome_rag(
             llm, embedding, streaming=True
         )
@@ -32,15 +35,62 @@ def __init__(
         # Set up nodes
         state_graph.add_node("preprocess", self.preprocess)
         state_graph.add_node("model", self.call_model)
+        state_graph.add_node("generate_unsafe_response", self.generate_unsafe_response)
         state_graph.add_node("postprocess", self.postprocess)
         # Set up edges
         state_graph.set_entry_point("preprocess")
-        state_graph.add_edge("preprocess", "model")
+        state_graph.add_conditional_edges(
+            "preprocess",
+            self.proceed_with_research,
+            {"Continue": "model", "Finish": "generate_unsafe_response"},
+        )
         state_graph.add_edge("model", "postprocess")
+        state_graph.add_edge("generate_unsafe_response", "postprocess")
         state_graph.set_finish_point("postprocess")
 
         self.uncompiled_graph: StateGraph = state_graph
 
+    async def proceed_with_research(
+        self, state: ReactToMeState
+    ) -> Literal["Continue", "Finish"]:
+        return "Continue" if state.get("safety") == SAFETY_SAFE else "Finish"
+
+    async def generate_unsafe_response(
+        self, state: ReactToMeState, config: RunnableConfig
+    ) -> ReactToMeState:
+        final_answer_message = await self.unsafe_answer_generator.ainvoke(
+            {
+                "user_input": state.get("rephrased_input", state["user_input"]),
+                "reason_unsafe": state.get("reason_unsafe", ""),
+            },
+            config,
+        )
+
+        final_answer = (
+            final_answer_message.content
+            if hasattr(final_answer_message, "content")
+            else str(final_answer_message)
+        )
+
+        history = list(state.get("chat_history", []))
+        history.extend(
+            [
+                HumanMessage(state["user_input"]),
+                (
+                    final_answer_message
+                    if hasattr(final_answer_message, "content")
+                    else AIMessage(final_answer)
+                ),
+            ]
+        )
+
+        return ReactToMeState(
+            chat_history=history,
+            answer=final_answer,
+            safety=SAFETY_UNSAFE,
+            additional_content={"search_results": []},
+        )
+
     async def call_model(
         self, state: ReactToMeState, config: RunnableConfig
     ) -> ReactToMeState:
@@ -48,18 +98,22 @@ async def call_model(
             {
                 "input": state["rephrased_input"],
                 "chat_history": (
-                    state["chat_history"]
-                    if state["chat_history"]
+                    state.get("chat_history")
+                    if state.get("chat_history")
                     else [HumanMessage(state["user_input"])]
                 ),
             },
             config,
         )
 
-        return ReactToMeState(
-            chat_history=[
+        history = list(state.get("chat_history", []))
+        history.extend(
+            [
                 HumanMessage(state["user_input"]),
                 AIMessage(result["answer"]),
-            ],
+            ]
+        )
+        return ReactToMeState(
+            chat_history=history,
             answer=result["answer"],
         )
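
For orientation, a minimal invocation sketch for the reworked profile graph. The builder class name, the model provider, and the example question are assumptions for illustration; only the `uncompiled_graph` attribute, the `user_input` key, and the `enable_postprocess` configurable come from this diff.

# Hypothetical usage sketch: compile a profile graph and toggle postprocessing
# via the standard LangChain `configurable` mapping (builder class name assumed).
import asyncio

from langchain_openai import ChatOpenAI, OpenAIEmbeddings  # assumed provider

from agent.profiles.react_to_me import ReactToMeGraphBuilder  # assumed class name


async def main() -> None:
    builder = ReactToMeGraphBuilder(ChatOpenAI(), OpenAIEmbeddings())
    graph = builder.uncompiled_graph.compile()
    result = await graph.ainvoke(
        {"user_input": "Which pathways involve TP53?"},  # example question
        config={"configurable": {"enable_postprocess": True}},
    )
    print(result["answer"])


asyncio.run(main())
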
diff --git a/src/agent/tasks/safety_checker.py b/src/agent/tasks/safety_checker.py
index ac372e3..d5af300 100644
--- a/src/agent/tasks/safety_checker.py
+++ b/src/agent/tasks/safety_checker.py
@@ -3,28 +3,57 @@
 from langchain_core.runnables import Runnable
 from pydantic import BaseModel, Field
 
-safety_check_message = """
-You are an expert grader with a strong understanding of ethical guidelines and responsible AI usage.
-
-Your task is to determine whether a user-submitted question is appropriate and not harmful.
-Provide a binary output as either:
-Yes: The question is appropriate and not harmful.
-No: The question is harmful, offensive, inappropriate, or unethical.
-"""
-
-safety_check_prompt = ChatPromptTemplate.from_messages(
-    [
-        ("system", safety_check_message),
-        ("human", "User question: \n\n {input}"),
-    ]
-)
-
 
 class SafetyCheck(BaseModel):
-    binary_score: str = Field(
-        description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'Yes' or 'No'."
+    safety: str = Field(
+        description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'true' or 'false'."
+    )
+    reason_unsafe: str = Field(
+        description="If 'safety' is false, briefly state the reason; if 'safety' is true, leave this field empty."
     )
 
 
 def create_safety_checker(llm: BaseChatModel) -> Runnable:
-    return safety_check_prompt | llm.with_structured_output(SafetyCheck)
+    """
+    Create a safety checker chain.
+
+    Args:
+        llm: Language model to use
+
+    Returns:
+        Runnable that takes rephrased_input and returns SafetyCheck
+    """
+    safety_check_message = """
+You are an expert scientific assistant. You have advanced training in scientific ethics, dual-use assessment, and responsible AI.
+
+Tasks:
+
+1. Safety & Appropriateness Check
+   - Assess the question and determine if it is safe and appropriate for responsible AI use.
+   - The question and any plausible answer to it must NOT be harmful, offensive, unethical, or violate scientific or ethical guidelines.
+   - The question or its likely answers must NOT request, describe, or explore how to carry out harmful, illegal, dangerous, or dual-use activities (such as modifying pathogens, increasing infectivity, or other forms of misuse)—even if the scenario is hypothetical, fictional, or indirect.
+   - The question must NOT request personal, medical, legal, or other types of advice outside the scope of scientific knowledge.
+   - If a question is ambiguous or could plausibly be used for harm, treat it as unsafe.
+   - Treat hypothetical, fictional, or made-up scenarios with the same level of scrutiny as real-world questions.
+
+2. Reactome Relevance Check
+   - Determine if the question is relevant to biology, life sciences, molecular biology, or related topics.
+   - Mark questions as not relevant if they are about unrelated topics (such as programming, math, history, trivia, etc.).
+
+IMPORTANT:
+   - If the standalone question is unsafe or not relevant return "safety": "false".
+   - If the standalone question is both safe **and** relevant, return "safety": "true".
+
+Return only a JSON object in the following format:
+   "safety": "true" or "false",   // Use string, not boolean. Mark as "false" if unsafe OR not relevant.
+   "reason_unsafe": "..."         // If 'safety' is false, briefly state the reason
+"""
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", safety_check_message),
+            ("human", "User question: \n\n {rephrased_input}"),
+        ]
+    )
+
+    return prompt | llm.with_structured_output(SafetyCheck)
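
A minimal standalone sketch of the new structured safety check. The model provider and the example question are assumptions; the `rephrased_input` variable and the `SafetyCheck` fields come from the diff above.

# Standalone sketch: the checker returns a SafetyCheck whose string fields are
# `safety` ("true"/"false") and `reason_unsafe` (OpenAI chat model assumed).
import asyncio

from langchain_openai import ChatOpenAI

from agent.tasks.safety_checker import SafetyCheck, create_safety_checker


async def main() -> None:
    checker = create_safety_checker(ChatOpenAI())
    verdict: SafetyCheck = await checker.ainvoke(
        {"rephrased_input": "How does TP53 regulate apoptosis?"}  # example question
    )
    print(verdict.safety, verdict.reason_unsafe)


asyncio.run(main())
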
diff --git a/src/agent/tasks/unsafe_answer.py b/src/agent/tasks/unsafe_answer.py
new file mode 100644
index 0000000..7bfd8ec
--- /dev/null
+++ b/src/agent/tasks/unsafe_answer.py
@@ -0,0 +1,46 @@
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import Runnable
+
+
+def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable:
+    """
+    Create an unsafe answer generator chain.
+
+    Args:
+        llm: Language model to use.
+
+    Returns:
+        Runnable that generates refusal messages for unsafe or out-of-scope queries.
+    """
+    system_prompt = """
+    You are an expert scientific assistant operating under the React-to-Me platform. React-to-Me helps both experts and non-experts explore molecular biology using trusted data from the Reactome database.
+
+You have advanced training in scientific ethics, dual-use research concerns, and responsible AI use.
+
+You will receive two inputs:
+1. The user's question.
+2. A system-generated variable called `reason_unsafe`, which explains why the question cannot be answered.
+
+Your task is to clearly, respectfully, and firmly explain to the user *why* their question cannot be answered, based solely on the `reason_unsafe` input. Do **not** attempt to answer, rephrase, or guide the user toward answering the original question.
+
+You must:
+- Politely explain the refusal, grounded in the `reason_unsafe`.
+- Emphasize React-to-Me's mission: to support responsible exploration of molecular biology through trusted databases.
+- Suggest examples of appropriate topics (e.g., protein function, pathways, gene interactions using Reactome/UniProt).
+
+You must not provide any workaround, implicit answer, or redirection toward unsafe content.
+"""
+    streaming_llm = llm.model_copy(update={"streaming": True})
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", system_prompt),
+            (
+                "user",
+                "Question:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}",
+            ),
+        ]
+    )
+
+    return prompt | streaming_llm
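
Because create_unsafe_answer_generator copies the LLM with streaming=True, the resulting chain can be streamed directly. A usage sketch, assuming an OpenAI chat model and made-up example inputs:

# Sketch: stream a refusal for an input the safety check rejected (inputs are
# illustrative only; the prompt variables come from the diff above).
import asyncio

from langchain_openai import ChatOpenAI

from agent.tasks.unsafe_answer import create_unsafe_answer_generator


async def main() -> None:
    refusal = create_unsafe_answer_generator(ChatOpenAI())
    async for chunk in refusal.astream(
        {
            "user_input": "How can a virus be made more infectious?",  # example only
            "reason_unsafe": "Dual-use request about increasing pathogen infectivity.",
        }
    ):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
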
+""" + +from tools.preprocessing.state import PreprocessingState +from tools.preprocessing.workflow import create_preprocessing_workflow + +__all__ = ["PreprocessingState", "create_preprocessing_workflow"] diff --git a/src/tools/preprocessing/state.py b/src/tools/preprocessing/state.py new file mode 100644 index 0000000..0e9910d --- /dev/null +++ b/src/tools/preprocessing/state.py @@ -0,0 +1,16 @@ +from typing import TypedDict + +from langchain_core.messages import BaseMessage + + +class PreprocessingState(TypedDict, total=False): + """State for the preprocessing workflow.""" + + # Inputs + user_input: str + chat_history: list[BaseMessage] + + # Task outputs + rephrased_input: str + safety: str + reason_unsafe: str diff --git a/src/tools/preprocessing/workflow.py b/src/tools/preprocessing/workflow.py new file mode 100644 index 0000000..fd5c0d0 --- /dev/null +++ b/src/tools/preprocessing/workflow.py @@ -0,0 +1,63 @@ +from typing import Any, Callable + +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.runnables import Runnable, RunnableConfig +from langgraph.graph import StateGraph +from langgraph.graph.state import CompiledStateGraph +from langgraph.utils.runnable import RunnableLike + +from agent.tasks.rephrase import create_rephrase_chain +from agent.tasks.safety_checker import create_safety_checker +from tools.preprocessing.state import PreprocessingState + + +def create_task_wrapper( + task: Runnable, + input_mapper: Callable[[PreprocessingState], dict[str, Any]], + output_mapper: Callable[[Any], PreprocessingState], +) -> RunnableLike: + """Wrap a runnable with state mappers.""" + + async def _wrapper( + state: PreprocessingState, config: RunnableConfig + ) -> PreprocessingState: + result = await task.ainvoke(input_mapper(state), config) + return output_mapper(result) + + return _wrapper + + +def create_preprocessing_workflow(llm: BaseChatModel) -> CompiledStateGraph: + """Create the preprocessing workflow with rephrasing and safety checking.""" + + tasks = { + "rephrase_query": ( + create_rephrase_chain(llm), + lambda state: { + "user_input": state["user_input"], + "chat_history": state.get("chat_history", []), + }, + lambda result: PreprocessingState(rephrased_input=result), + ), + "safety_check": ( + create_safety_checker(llm), + lambda state: {"rephrased_input": state["rephrased_input"]}, + lambda result: PreprocessingState( + safety=result.safety.lower(), + reason_unsafe=result.reason_unsafe, + ), + ), + } + + workflow = StateGraph(PreprocessingState) + + for node_name, (task, input_mapper, output_mapper) in tasks.items(): + workflow.add_node( + node_name, create_task_wrapper(task, input_mapper, output_mapper) + ) + + workflow.set_entry_point("rephrase_query") + workflow.add_edge("rephrase_query", "safety_check") + workflow.set_finish_point("safety_check") + + return workflow.compile()