From 30e7dff4d69a0ae2e6e3f35fdc9c0edd8e1342d3 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Fri, 7 Nov 2025 18:59:52 -0500 Subject: [PATCH 01/12] refactor: centralize preprocessing in BaseGraphBuilder --- src/agent/profiles/base.py | 48 ++++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py index 9a6e26c..ea1a21c 100644 --- a/src/agent/profiles/base.py +++ b/src/agent/profiles/base.py @@ -1,4 +1,4 @@ -from typing import Annotated, TypedDict +from typing import Annotated, Literal, TypedDict from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel @@ -6,9 +6,14 @@ from langchain_core.runnables import Runnable, RunnableConfig from langgraph.graph.message import add_messages -from agent.tasks.rephrase import create_rephrase_chain from tools.external_search.state import SearchState, WebSearchResult from tools.external_search.workflow import create_search_workflow +from tools.preprocessing.state import PreprocessingState +from tools.preprocessing.workflow import create_preprocessing_workflow + +SAFETY_SAFE: Literal["true"] = "true" +SAFETY_UNSAFE: Literal["false"] = "false" +DEFAULT_LANGUAGE: str = "English" class AdditionalContent(TypedDict, total=False): @@ -20,39 +25,50 @@ class InputState(TypedDict, total=False): class OutputState(TypedDict, total=False): - answer: str # primary LLM response that is streamed to the user + answer: str # LLM response streamed to the user additional_content: AdditionalContent # sends on graph completion class BaseState(InputState, OutputState, total=False): - rephrased_input: str # LLM-generated query from user input + rephrased_input: str # contextualized, LLM-generated standalone query from user input chat_history: Annotated[list[BaseMessage], add_messages] + safety: str # LLM-assessed safety level of the user input + reason_unsafe: str class BaseGraphBuilder: - # NOTE: Anything that is common to all graph builders goes here def __init__( self, llm: BaseChatModel, embedding: Embeddings, ) -> None: - self.rephrase_chain: Runnable = create_rephrase_chain(llm) + self.preprocessing_workflow: Runnable = create_preprocessing_workflow(llm) self.search_workflow: Runnable = create_search_workflow(llm) async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: - rephrased_input: str = await self.rephrase_chain.ainvoke( - { - "user_input": state["user_input"], - "chat_history": state["chat_history"], - }, + result: PreprocessingState = await self.preprocessing_workflow.ainvoke( + PreprocessingState( + user_input=state["user_input"], + chat_history=state.get("chat_history", []), + ), config, ) - return BaseState(rephrased_input=rephrased_input) + mapped_state = BaseState( + rephrased_input=result.get("rephrased_input", ""), + safety=result.get("safety", SAFETY_SAFE), + reason_unsafe=result.get("reason_unsafe", ""), + ) + merged_state = dict(state) + merged_state.update(mapped_state) + return BaseState(**merged_state) async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: search_results: list[WebSearchResult] = [] - if config["configurable"]["enable_postprocess"]: + if ( + state.get("safety") == SAFETY_SAFE + and config["configurable"]["enable_postprocess"] + ): result: SearchState = await self.search_workflow.ainvoke( SearchState( input=state["rephrased_input"], @@ -61,6 +77,8 @@ async def postprocess(self, state: BaseState, config: RunnableConfig) -> 
BaseSta config=RunnableConfig(callbacks=config["callbacks"]), ) search_results = result["search_results"] - return BaseState( - additional_content=AdditionalContent(search_results=search_results) + merged_state = dict(state) + merged_state.update( + {"additional_content": AdditionalContent(search_results=search_results)} ) + return BaseState(**merged_state) From 7c8f50fec0ec860efe8e8c6a9fa0fcd6932b4878 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Fri, 7 Nov 2025 19:13:32 -0500 Subject: [PATCH 02/12] feat: route unsafe queries in ReactToMe --- src/agent/profiles/react_to_me.py | 72 ++++++++++++++++++++++++++++--- 1 file changed, 66 insertions(+), 6 deletions(-) diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py index c162ac7..adb2877 100644 --- a/src/agent/profiles/react_to_me.py +++ b/src/agent/profiles/react_to_me.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Literal from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel @@ -6,7 +6,16 @@ from langchain_core.runnables import Runnable, RunnableConfig from langgraph.graph.state import StateGraph -from agent.profiles.base import BaseGraphBuilder, BaseState +from agent.profiles.base import ( + DEFAULT_LANGUAGE, + SAFETY_SAFE, + SAFETY_UNSAFE, + BaseGraphBuilder, + BaseState, +) +from agent.tasks.final_answer_generation.unsafe_question import ( + create_unsafe_answer_generator, +) from retrievers.reactome.rag import create_reactome_rag @@ -23,6 +32,9 @@ def __init__( super().__init__(llm, embedding) # Create runnables (tasks & tools) + streaming_llm = llm.model_copy(update={"streaming": True}) + + self.unsafe_answer_generator = create_unsafe_answer_generator(streaming_llm) self.reactome_rag: Runnable = create_reactome_rag( llm, embedding, streaming=True ) @@ -32,36 +44,84 @@ def __init__( # Set up nodes state_graph.add_node("preprocess", self.preprocess) state_graph.add_node("model", self.call_model) + state_graph.add_node("generate_unsafe_response", self.generate_unsafe_response) state_graph.add_node("postprocess", self.postprocess) # Set up edges state_graph.set_entry_point("preprocess") - state_graph.add_edge("preprocess", "model") + state_graph.add_conditional_edges( + "preprocess", + self.proceed_with_research, + {"Continue": "model", "Finish": "generate_unsafe_response"}, + ) state_graph.add_edge("model", "postprocess") + state_graph.add_edge("generate_unsafe_response", "postprocess") state_graph.set_finish_point("postprocess") self.uncompiled_graph: StateGraph = state_graph + async def proceed_with_research( + self, state: ReactToMeState + ) -> Literal["Continue", "Finish"]: + return "Continue" if state.get("safety") == SAFETY_SAFE else "Finish" + + async def generate_unsafe_response( + self, state: ReactToMeState, config: RunnableConfig + ) -> ReactToMeState: + final_answer_message = await self.unsafe_answer_generator.ainvoke( + { + "language": state.get("detected_language", DEFAULT_LANGUAGE), + "user_input": state.get("rephrased_input", state["user_input"]), + "reason_unsafe": state.get("reason_unsafe", ""), + }, + config, + ) + + final_answer = ( + final_answer_message.content + if hasattr(final_answer_message, "content") + else str(final_answer_message) + ) + + updated_state = dict(state) + updated_state.update( + chat_history=[ + HumanMessage(state["user_input"]), + ( + final_answer_message + if hasattr(final_answer_message, "content") + else AIMessage(final_answer) + ), + ], + answer=final_answer, + 
safety=SAFETY_UNSAFE, + additional_content={"search_results": []}, + ) + return ReactToMeState(**updated_state) + async def call_model( self, state: ReactToMeState, config: RunnableConfig ) -> ReactToMeState: result: dict[str, Any] = await self.reactome_rag.ainvoke( { "input": state["rephrased_input"], + "expanded_queries": state.get("expanded_queries", []), "chat_history": ( - state["chat_history"] - if state["chat_history"] + state.get("chat_history") + if state.get("chat_history") else [HumanMessage(state["user_input"])] ), }, config, ) - return ReactToMeState( + updated_state = dict(state) + updated_state.update( chat_history=[ HumanMessage(state["user_input"]), AIMessage(result["answer"]), ], answer=result["answer"], ) + return ReactToMeState(**updated_state) def create_reactome_graph( From 734f361ed8da5b8b7965bf9670750f43e51e0043 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Fri, 7 Nov 2025 19:21:17 -0500 Subject: [PATCH 03/12] feat: add safety routing to ReactToMe graph --- src/agent/profiles/react_to_me.py | 4 +-- src/agent/tasks/unsafe_answer.py | 48 +++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 src/agent/tasks/unsafe_answer.py diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py index adb2877..863a264 100644 --- a/src/agent/profiles/react_to_me.py +++ b/src/agent/profiles/react_to_me.py @@ -13,9 +13,7 @@ BaseGraphBuilder, BaseState, ) -from agent.tasks.final_answer_generation.unsafe_question import ( - create_unsafe_answer_generator, -) +from agent.tasks.unsafe_answer import create_unsafe_answer_generator from retrievers.reactome.rag import create_reactome_rag diff --git a/src/agent/tasks/unsafe_answer.py b/src/agent/tasks/unsafe_answer.py new file mode 100644 index 0000000..27233b3 --- /dev/null +++ b/src/agent/tasks/unsafe_answer.py @@ -0,0 +1,48 @@ +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables import Runnable + + +def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable: + """ + Create an unsafe answer generator chain. + + Args: + llm: Language model to use. + + Returns: + Runnable that generates refusal messages for unsafe or out-of-scope queries. + """ + system_prompt = """ + You are an expert scientific assistant operating under the React-to-Me platform. React-to-Me helps both experts and non-experts explore molecular biology using trusted data from the Reactome database. + +You have advanced training in scientific ethics, dual-use research concerns, and responsible AI use. + +You will receive three inputs: +1. The user's question. +2. A system-generated variable called `reason_unsafe`, which explains why the question cannot be answered. +3. The user's preferred language (as a language code or name). + +Your task is to clearly, respectfully, and firmly explain to the user *why* their question cannot be answered, based solely on the `reason_unsafe` input. Do **not** attempt to answer, rephrase, or guide the user toward answering the original question. + +You must: +- Respond in the user’s preferred language. +- Politely explain the refusal, grounded in the `reason_unsafe`. +- Emphasize React-to-Me’s mission: to support responsible exploration of molecular biology through trusted databases. +- Suggest examples of appropriate topics (e.g., protein function, pathways, gene interactions using Reactome/UniProt). 
+ +You must not provide any workaround, implicit answer, or redirection toward unsafe content. +""" + prompt = ChatPromptTemplate.from_messages( + [ + ("system", system_prompt), + ( + "user", + "Language:{language}\n\nQuestion:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}", + ), + ] + ) + + return prompt | llm + + From fdc5cd24353df9ee7b2b8e03a4adf18f9d8e7ab0 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Fri, 7 Nov 2025 19:22:54 -0500 Subject: [PATCH 04/12] feat: strengthen safety checker policy --- src/agent/tasks/safety_checker.py | 67 ++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 19 deletions(-) diff --git a/src/agent/tasks/safety_checker.py b/src/agent/tasks/safety_checker.py index ac372e3..d5af300 100644 --- a/src/agent/tasks/safety_checker.py +++ b/src/agent/tasks/safety_checker.py @@ -3,28 +3,57 @@ from langchain_core.runnables import Runnable from pydantic import BaseModel, Field -safety_check_message = """ -You are an expert grader with a strong understanding of ethical guidelines and responsible AI usage. - -Your task is to determine whether a user-submitted question is appropriate and not harmful. -Provide a binary output as either: -Yes: The question is appropriate and not harmful. -No: The question is harmful, offensive, inappropriate, or unethical. -""" - -safety_check_prompt = ChatPromptTemplate.from_messages( - [ - ("system", safety_check_message), - ("human", "User question: \n\n {input}"), - ] -) - class SafetyCheck(BaseModel): - binary_score: str = Field( - description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'Yes' or 'No'." + safety: str = Field( + description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'true' or 'false'." + ) + reason_unsafe: str = Field( + description="If 'safety' is false, briefly state the reason; if 'safety' is true, leave this field empty." ) def create_safety_checker(llm: BaseChatModel) -> Runnable: - return safety_check_prompt | llm.with_structured_output(SafetyCheck) + """ + Create a safety checker chain. + + Args: + llm: Language model to use + + Returns: + Runnable that takes rephrased_input and returns SafetyCheck + """ + safety_check_message = """ +You are an expert scientific assistant. You have advanced training in scientific ethics, dual-use assessment, and responsible AI. + +Tasks: + +1. Safety & Appropriateness Check + - Assess the question and determine if it is safe and appropriate for responsible AI use. + - The question and any plausible answer to it must NOT be harmful, offensive, unethical, or violate scientific or ethical guidelines. + - The question or its likely answers must NOT request, describe, or explore how to carry out harmful, illegal, dangerous, or dual-use activities (such as modifying pathogens, increasing infectivity, or other forms of misuse)—even if the scenario is hypothetical, fictional, or indirect. + - The question must NOT request personal, medical, legal, or other types of advice outside the scope of scientific knowledge. + - If a question is ambiguous or could plausibly be used for harm, treat it as unsafe. + - Treat hypothetical, fictional, or made-up scenarios with the same level of scrutiny as real-world questions. + +2. Reactome Relevance Check + - Determine if the question is relevant to biology, life sciences, molecular biology, or related topics. 
+ - Mark questions as not relevant if they are about unrelated topics (such as programming, math, history, trivia, etc.). + +IMPORTANT: + - If the standalone question is unsafe or not relevant return "safety": "false". + - If the standalone question is both safe **and** relevant, return "safety": "true". + +Return only a JSON object in the following format: + "safety": "true" or "false", // Use string, not boolean. Mark as "false" if unsafe OR not relevant. + "reason_unsafe": "..." // If 'safety' is false, briefly state the reason +""" + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", safety_check_message), + ("human", "User question: \n\n {rephrased_input}"), + ] + ) + + return prompt | llm.with_structured_output(SafetyCheck) From c8f9f7105adab26dcec4de8e76338ccccc244480 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Mon, 10 Nov 2025 11:06:48 -0500 Subject: [PATCH 05/12] feat: add preprocessing workflow module --- src/tools/preprocessing/__init__.py | 8 ++++ src/tools/preprocessing/state.py | 20 +++++++++ src/tools/preprocessing/workflow.py | 65 +++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+) create mode 100644 src/tools/preprocessing/__init__.py create mode 100644 src/tools/preprocessing/state.py create mode 100644 src/tools/preprocessing/workflow.py diff --git a/src/tools/preprocessing/__init__.py b/src/tools/preprocessing/__init__.py new file mode 100644 index 0000000..84072a5 --- /dev/null +++ b/src/tools/preprocessing/__init__.py @@ -0,0 +1,8 @@ +""" +Preprocessing utilities with reusable workflows and state definitions. +""" + +from .state import PreprocessingState # noqa: F401 +from .workflow import create_preprocessing_workflow # noqa: F401 + + diff --git a/src/tools/preprocessing/state.py b/src/tools/preprocessing/state.py new file mode 100644 index 0000000..c3836ed --- /dev/null +++ b/src/tools/preprocessing/state.py @@ -0,0 +1,20 @@ +from typing import TypedDict + +from langchain_core.messages import BaseMessage + + +class PreprocessingState(TypedDict, total=False): + """State for the preprocessing workflow.""" + + # Inputs + user_input: str + chat_history: list[BaseMessage] + + # Task outputs + rephrased_input: str + safety: str + reason_unsafe: str + expanded_queries: list[str] + detected_language: str + + diff --git a/src/tools/preprocessing/workflow.py b/src/tools/preprocessing/workflow.py new file mode 100644 index 0000000..4388863 --- /dev/null +++ b/src/tools/preprocessing/workflow.py @@ -0,0 +1,65 @@ +from typing import Any, Callable + +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.runnables import Runnable, RunnableConfig +from langgraph.graph import StateGraph +from langgraph.graph.state import CompiledStateGraph +from langgraph.utils.runnable import RunnableLike + +from agent.tasks.rephrase import create_rephrase_chain +from agent.tasks.safety_checker import create_safety_checker +from tools.preprocessing.state import PreprocessingState + + +def create_task_wrapper( + task: Runnable, + input_mapper: Callable[[PreprocessingState], dict[str, Any]], + output_mapper: Callable[[Any], PreprocessingState], +) -> RunnableLike: + """Wrap a runnable with state mappers.""" + + async def _wrapper( + state: PreprocessingState, config: RunnableConfig + ) -> PreprocessingState: + result = await task.ainvoke(input_mapper(state), config) + return output_mapper(result) + + return _wrapper + + +def create_preprocessing_workflow(llm: BaseChatModel) -> CompiledStateGraph: + """Create the preprocessing 
workflow with rephrasing and safety checking.""" + + tasks = { + "rephrase_query": ( + create_rephrase_chain(llm), + lambda state: { + "user_input": state["user_input"], + "chat_history": state.get("chat_history", []), + }, + lambda result: PreprocessingState(rephrased_input=result), + ), + "safety_check": ( + create_safety_checker(llm), + lambda state: {"rephrased_input": state["rephrased_input"]}, + lambda result: PreprocessingState( + safety=result.safety.lower(), + reason_unsafe=result.reason_unsafe, + ), + ), + } + + workflow = StateGraph(PreprocessingState) + + for node_name, (task, input_mapper, output_mapper) in tasks.items(): + workflow.add_node( + node_name, create_task_wrapper(task, input_mapper, output_mapper) + ) + + workflow.set_entry_point("rephrase_query") + workflow.add_edge("rephrase_query", "safety_check") + workflow.set_finish_point("safety_check") + + return workflow.compile() + + From 6378148edc40b08b961cd20d06022ba62a7354ec Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Mon, 10 Nov 2025 11:31:24 -0500 Subject: [PATCH 06/12] Apply formatting fixes after lint run --- src/agent/profiles/base.py | 6 ++++-- src/agent/profiles/react_to_me.py | 9 ++------- src/agent/tasks/unsafe_answer.py | 2 -- src/tools/preprocessing/__init__.py | 2 -- src/tools/preprocessing/state.py | 2 -- src/tools/preprocessing/workflow.py | 2 -- 6 files changed, 6 insertions(+), 17 deletions(-) diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py index ea1a21c..97b2017 100644 --- a/src/agent/profiles/base.py +++ b/src/agent/profiles/base.py @@ -30,10 +30,12 @@ class OutputState(TypedDict, total=False): class BaseState(InputState, OutputState, total=False): - rephrased_input: str # contextualized, LLM-generated standalone query from user input + rephrased_input: ( + str # contextualized, LLM-generated standalone query from user input + ) chat_history: Annotated[list[BaseMessage], add_messages] safety: str # LLM-assessed safety level of the user input - reason_unsafe: str + reason_unsafe: str class BaseGraphBuilder: diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py index 863a264..1720c71 100644 --- a/src/agent/profiles/react_to_me.py +++ b/src/agent/profiles/react_to_me.py @@ -6,13 +6,8 @@ from langchain_core.runnables import Runnable, RunnableConfig from langgraph.graph.state import StateGraph -from agent.profiles.base import ( - DEFAULT_LANGUAGE, - SAFETY_SAFE, - SAFETY_UNSAFE, - BaseGraphBuilder, - BaseState, -) +from agent.profiles.base import (DEFAULT_LANGUAGE, SAFETY_SAFE, SAFETY_UNSAFE, + BaseGraphBuilder, BaseState) from agent.tasks.unsafe_answer import create_unsafe_answer_generator from retrievers.reactome.rag import create_reactome_rag diff --git a/src/agent/tasks/unsafe_answer.py b/src/agent/tasks/unsafe_answer.py index 27233b3..cd90aa7 100644 --- a/src/agent/tasks/unsafe_answer.py +++ b/src/agent/tasks/unsafe_answer.py @@ -44,5 +44,3 @@ def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable: ) return prompt | llm - - diff --git a/src/tools/preprocessing/__init__.py b/src/tools/preprocessing/__init__.py index 84072a5..9b2d5ce 100644 --- a/src/tools/preprocessing/__init__.py +++ b/src/tools/preprocessing/__init__.py @@ -4,5 +4,3 @@ from .state import PreprocessingState # noqa: F401 from .workflow import create_preprocessing_workflow # noqa: F401 - - diff --git a/src/tools/preprocessing/state.py b/src/tools/preprocessing/state.py index c3836ed..430a554 100644 --- a/src/tools/preprocessing/state.py +++ 
b/src/tools/preprocessing/state.py @@ -16,5 +16,3 @@ class PreprocessingState(TypedDict, total=False): reason_unsafe: str expanded_queries: list[str] detected_language: str - - diff --git a/src/tools/preprocessing/workflow.py b/src/tools/preprocessing/workflow.py index 4388863..fd5c0d0 100644 --- a/src/tools/preprocessing/workflow.py +++ b/src/tools/preprocessing/workflow.py @@ -61,5 +61,3 @@ def create_preprocessing_workflow(llm: BaseChatModel) -> CompiledStateGraph: workflow.set_finish_point("safety_check") return workflow.compile() - - From cf347f94db3ae724be771f7e8043c373a77d87ca Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Mon, 10 Nov 2025 11:47:32 -0500 Subject: [PATCH 07/12] style: simplify state merges in BaseGraphBuilder --- src/agent/profiles/base.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py index 97b2017..0560c39 100644 --- a/src/agent/profiles/base.py +++ b/src/agent/profiles/base.py @@ -61,9 +61,7 @@ async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseStat safety=result.get("safety", SAFETY_SAFE), reason_unsafe=result.get("reason_unsafe", ""), ) - merged_state = dict(state) - merged_state.update(mapped_state) - return BaseState(**merged_state) + return BaseState(**state, **mapped_state) async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: search_results: list[WebSearchResult] = [] @@ -79,8 +77,7 @@ async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseSta config=RunnableConfig(callbacks=config["callbacks"]), ) search_results = result["search_results"] - merged_state = dict(state) - merged_state.update( - {"additional_content": AdditionalContent(search_results=search_results)} + return BaseState( + **state, + additional_content=AdditionalContent(search_results=search_results), ) - return BaseState(**merged_state) From bbc654988c43e4f5a1c5fd361a430b053e393537 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Fri, 21 Nov 2025 11:49:37 -0500 Subject: [PATCH 08/12] fix: address reviewer import and streaming feedback --- src/agent/profiles/react_to_me.py | 31 +++++++++++++++++------------ src/agent/tasks/unsafe_answer.py | 4 +++- src/tools/preprocessing/__init__.py | 6 ++++-- 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py index 1720c71..8af121c 100644 --- a/src/agent/profiles/react_to_me.py +++ b/src/agent/profiles/react_to_me.py @@ -25,9 +25,7 @@ def __init__( super().__init__(llm, embedding) # Create runnables (tasks & tools) - streaming_llm = llm.model_copy(update={"streaming": True}) - - self.unsafe_answer_generator = create_unsafe_answer_generator(streaming_llm) + self.unsafe_answer_generator = create_unsafe_answer_generator(llm) self.reactome_rag: Runnable = create_reactome_rag( llm, embedding, streaming=True ) @@ -75,21 +73,25 @@ async def generate_unsafe_response( else str(final_answer_message) ) - updated_state = dict(state) - updated_state.update( - chat_history=[ + history = list(state.get("chat_history", [])) + history.extend( + [ HumanMessage(state["user_input"]), ( final_answer_message if hasattr(final_answer_message, "content") else AIMessage(final_answer) ), - ], + ] + ) + + return ReactToMeState( + **state, + chat_history=history, answer=final_answer, safety=SAFETY_UNSAFE, additional_content={"search_results": []}, ) - return ReactToMeState(**updated_state) async def call_model( self, state: 
ReactToMeState, config: RunnableConfig @@ -106,15 +108,18 @@ async def call_model( }, config, ) - updated_state = dict(state) - updated_state.update( - chat_history=[ + history = list(state.get("chat_history", [])) + history.extend( + [ HumanMessage(state["user_input"]), AIMessage(result["answer"]), - ], + ] + ) + return ReactToMeState( + **state, + chat_history=history, answer=result["answer"], ) - return ReactToMeState(**updated_state) def create_reactome_graph( diff --git a/src/agent/tasks/unsafe_answer.py b/src/agent/tasks/unsafe_answer.py index cd90aa7..ff83fe6 100644 --- a/src/agent/tasks/unsafe_answer.py +++ b/src/agent/tasks/unsafe_answer.py @@ -33,6 +33,8 @@ def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable: You must not provide any workaround, implicit answer, or redirection toward unsafe content. """ + streaming_llm = llm.model_copy(update={"streaming": True}) + prompt = ChatPromptTemplate.from_messages( [ ("system", system_prompt), @@ -43,4 +45,4 @@ def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable: ] ) - return prompt | llm + return prompt | streaming_llm diff --git a/src/tools/preprocessing/__init__.py b/src/tools/preprocessing/__init__.py index 9b2d5ce..01a26cf 100644 --- a/src/tools/preprocessing/__init__.py +++ b/src/tools/preprocessing/__init__.py @@ -2,5 +2,7 @@ Preprocessing utilities with reusable workflows and state definitions. """ -from .state import PreprocessingState # noqa: F401 -from .workflow import create_preprocessing_workflow # noqa: F401 +from tools.preprocessing.state import PreprocessingState +from tools.preprocessing.workflow import create_preprocessing_workflow + +__all__ = ["PreprocessingState", "create_preprocessing_workflow"] From 855ad4d24ba369c584af5f4c0b58dce50881a02a Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Fri, 21 Nov 2025 20:07:56 -0500 Subject: [PATCH 09/12] fix: address reviewer feedback and resolve state management issues --- src/agent/profiles/base.py | 5 ++--- src/agent/profiles/cross_database.py | 18 +----------------- src/agent/profiles/react_to_me.py | 8 ++------ src/agent/tasks/unsafe_answer.py | 8 +++----- src/tools/preprocessing/state.py | 2 -- 5 files changed, 8 insertions(+), 33 deletions(-) diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py index 0560c39..fd7ec68 100644 --- a/src/agent/profiles/base.py +++ b/src/agent/profiles/base.py @@ -65,9 +65,8 @@ async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseStat async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: search_results: list[WebSearchResult] = [] - if ( - state.get("safety") == SAFETY_SAFE - and config["configurable"]["enable_postprocess"] + if state.get("safety") == SAFETY_SAFE and config.get("configurable", {}).get( + "enable_postprocess", False ): result: SearchState = await self.search_workflow.ainvoke( SearchState( diff --git a/src/agent/profiles/cross_database.py b/src/agent/profiles/cross_database.py index 74ef26c..8a725f1 100644 --- a/src/agent/profiles/cross_database.py +++ b/src/agent/profiles/cross_database.py @@ -15,16 +15,12 @@ create_uniprot_rewriter_w_reactome from agent.tasks.cross_database.summarize_reactome_uniprot import \ create_reactome_uniprot_summarizer -from agent.tasks.detect_language import create_language_detector from agent.tasks.safety_checker import SafetyCheck, create_safety_checker from retrievers.reactome.rag import create_reactome_rag from retrievers.uniprot.rag import create_uniprot_rag class CrossDatabaseState(BaseState): 
- safety: str # LLM-assessed safety level of the user input - query_language: str # language of the user input - reactome_query: str # LLM-generated query for Reactome reactome_answer: str # LLM-generated answer from Reactome reactome_completeness: str # LLM-assessed completeness of the Reactome answer @@ -48,7 +44,6 @@ def __init__( self.safety_checker = create_safety_checker(llm) self.completeness_checker = create_completeness_grader(llm) - self.detect_language = create_language_detector(llm) self.write_reactome_query = create_reactome_rewriter_w_uniprot(llm) self.write_uniprot_query = create_uniprot_rewriter_w_reactome(llm) self.summarize_final_answer = create_reactome_uniprot_summarizer( @@ -60,7 +55,6 @@ def __init__( # Set up nodes state_graph.add_node("check_question_safety", self.check_question_safety) state_graph.add_node("preprocess_question", self.preprocess) - state_graph.add_node("identify_query_language", self.identify_query_language) state_graph.add_node("conduct_research", self.conduct_research) state_graph.add_node("generate_reactome_answer", self.generate_reactome_answer) state_graph.add_node("rewrite_reactome_query", self.rewrite_reactome_query) @@ -74,7 +68,6 @@ def __init__( state_graph.add_node("postprocess", self.postprocess) # Set up edges state_graph.set_entry_point("preprocess_question") - state_graph.add_edge("preprocess_question", "identify_query_language") state_graph.add_edge("preprocess_question", "check_question_safety") state_graph.add_conditional_edges( "check_question_safety", @@ -112,7 +105,7 @@ async def check_question_safety( config, ) if result.binary_score == "No": - inappropriate_input = f"This is the user's question and it is NOT appropriate for you to answer: {state["user_input"]}. \n\n explain that you are unable to answer the question but you can answer questions about topics related to the Reactome Pathway Knowledgebase or UniProt Knowledgebas." + inappropriate_input = f"This is the user's question and it is NOT appropriate for you to answer: {state["user_input"]}. \n\n explain that you are unable to answer the question but you can answer questions about topics related to the Reactome Pathway Knowledgebase or UniProt Knowledgebase." 
return CrossDatabaseState( safety=result.binary_score, user_input=inappropriate_input, @@ -130,14 +123,6 @@ async def proceed_with_research( else: return "Finish" - async def identify_query_language( - self, state: CrossDatabaseState, config: RunnableConfig - ) -> CrossDatabaseState: - query_language: str = await self.detect_language.ainvoke( - {"user_input": state["user_input"]}, config - ) - return CrossDatabaseState(query_language=query_language) - async def conduct_research( self, state: CrossDatabaseState, config: RunnableConfig ) -> CrossDatabaseState: @@ -256,7 +241,6 @@ async def generate_final_response( final_response: str = await self.summarize_final_answer.ainvoke( { "input": state["rephrased_input"], - "query_language": state["query_language"], "reactome_answer": state["reactome_answer"], "uniprot_answer": state["uniprot_answer"], }, diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py index 8af121c..54b9bcb 100644 --- a/src/agent/profiles/react_to_me.py +++ b/src/agent/profiles/react_to_me.py @@ -6,8 +6,8 @@ from langchain_core.runnables import Runnable, RunnableConfig from langgraph.graph.state import StateGraph -from agent.profiles.base import (DEFAULT_LANGUAGE, SAFETY_SAFE, SAFETY_UNSAFE, - BaseGraphBuilder, BaseState) +from agent.profiles.base import (SAFETY_SAFE, SAFETY_UNSAFE, BaseGraphBuilder, + BaseState) from agent.tasks.unsafe_answer import create_unsafe_answer_generator from retrievers.reactome.rag import create_reactome_rag @@ -60,7 +60,6 @@ async def generate_unsafe_response( ) -> ReactToMeState: final_answer_message = await self.unsafe_answer_generator.ainvoke( { - "language": state.get("detected_language", DEFAULT_LANGUAGE), "user_input": state.get("rephrased_input", state["user_input"]), "reason_unsafe": state.get("reason_unsafe", ""), }, @@ -86,7 +85,6 @@ async def generate_unsafe_response( ) return ReactToMeState( - **state, chat_history=history, answer=final_answer, safety=SAFETY_UNSAFE, @@ -99,7 +97,6 @@ async def call_model( result: dict[str, Any] = await self.reactome_rag.ainvoke( { "input": state["rephrased_input"], - "expanded_queries": state.get("expanded_queries", []), "chat_history": ( state.get("chat_history") if state.get("chat_history") @@ -116,7 +113,6 @@ async def call_model( ] ) return ReactToMeState( - **state, chat_history=history, answer=result["answer"], ) diff --git a/src/agent/tasks/unsafe_answer.py b/src/agent/tasks/unsafe_answer.py index ff83fe6..7bfd8ec 100644 --- a/src/agent/tasks/unsafe_answer.py +++ b/src/agent/tasks/unsafe_answer.py @@ -18,17 +18,15 @@ def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable: You have advanced training in scientific ethics, dual-use research concerns, and responsible AI use. -You will receive three inputs: +You will receive two inputs: 1. The user's question. 2. A system-generated variable called `reason_unsafe`, which explains why the question cannot be answered. -3. The user's preferred language (as a language code or name). Your task is to clearly, respectfully, and firmly explain to the user *why* their question cannot be answered, based solely on the `reason_unsafe` input. Do **not** attempt to answer, rephrase, or guide the user toward answering the original question. You must: -- Respond in the user’s preferred language. - Politely explain the refusal, grounded in the `reason_unsafe`. -- Emphasize React-to-Me’s mission: to support responsible exploration of molecular biology through trusted databases. 
+- Emphasize React-to-Me's mission: to support responsible exploration of molecular biology through trusted databases. - Suggest examples of appropriate topics (e.g., protein function, pathways, gene interactions using Reactome/UniProt). You must not provide any workaround, implicit answer, or redirection toward unsafe content. @@ -40,7 +38,7 @@ def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable: ("system", system_prompt), ( "user", - "Language:{language}\n\nQuestion:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}", + "Question:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}", ), ] ) diff --git a/src/tools/preprocessing/state.py b/src/tools/preprocessing/state.py index 430a554..0e9910d 100644 --- a/src/tools/preprocessing/state.py +++ b/src/tools/preprocessing/state.py @@ -14,5 +14,3 @@ class PreprocessingState(TypedDict, total=False): rephrased_input: str safety: str reason_unsafe: str - expanded_queries: list[str] - detected_language: str From dd14c8a0e6d031aafacb406a47f6245ac41b5821 Mon Sep 17 00:00:00 2001 From: Greg Hogue Date: Fri, 16 Jan 2026 15:46:01 -0500 Subject: [PATCH 10/12] clean up safety changes --- src/agent/profiles/base.py | 51 +++++++++++----------- src/agent/profiles/cross_database.py | 18 +++++++- src/agent/profiles/react_to_me.py | 57 +++++++------------------ src/agent/tasks/safety_checker.py | 45 ++++++++------------ src/agent/tasks/unsafe_answer.py | 40 ++++++++---------- src/tools/preprocessing/__init__.py | 8 ---- src/tools/preprocessing/state.py | 16 ------- src/tools/preprocessing/workflow.py | 63 ---------------------------- 8 files changed, 93 insertions(+), 205 deletions(-) delete mode 100644 src/tools/preprocessing/__init__.py delete mode 100644 src/tools/preprocessing/state.py delete mode 100644 src/tools/preprocessing/workflow.py diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py index fd7ec68..f594435 100644 --- a/src/agent/profiles/base.py +++ b/src/agent/profiles/base.py @@ -1,4 +1,4 @@ -from typing import Annotated, Literal, TypedDict +from typing import Annotated, TypedDict from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel @@ -6,14 +6,10 @@ from langchain_core.runnables import Runnable, RunnableConfig from langgraph.graph.message import add_messages +from agent.tasks.rephrase import create_rephrase_chain +from agent.tasks.safety_checker import create_safety_checker from tools.external_search.state import SearchState, WebSearchResult from tools.external_search.workflow import create_search_workflow -from tools.preprocessing.state import PreprocessingState -from tools.preprocessing.workflow import create_preprocessing_workflow - -SAFETY_SAFE: Literal["true"] = "true" -SAFETY_UNSAFE: Literal["false"] = "false" -DEFAULT_LANGUAGE: str = "English" class AdditionalContent(TypedDict, total=False): @@ -25,48 +21,52 @@ class InputState(TypedDict, total=False): class OutputState(TypedDict, total=False): - answer: str # LLM response streamed to the user + answer: str # primary LLM response that is streamed to the user additional_content: AdditionalContent # sends on graph completion class BaseState(InputState, OutputState, total=False): - rephrased_input: ( - str # contextualized, LLM-generated standalone query from user input - ) + rephrased_input: str # LLM-generated query from user input chat_history: Annotated[list[BaseMessage], add_messages] safety: str # LLM-assessed safety level of the user input reason_unsafe: str 
class BaseGraphBuilder: + # NOTE: Anything that is common to all graph builders goes here def __init__( self, llm: BaseChatModel, embedding: Embeddings, ) -> None: - self.preprocessing_workflow: Runnable = create_preprocessing_workflow(llm) + self.rephrase_chain: Runnable = create_rephrase_chain(llm) + self.safety_checker: Runnable = create_safety_checker(llm) self.search_workflow: Runnable = create_search_workflow(llm) async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: - result: PreprocessingState = await self.preprocessing_workflow.ainvoke( - PreprocessingState( - user_input=state["user_input"], - chat_history=state.get("chat_history", []), - ), + rephrased_input: str = await self.rephrase_chain.ainvoke( + { + "user_input": state["user_input"], + "chat_history": state.get("chat_history", []), + }, + config, + ) + safety_check: BaseState = await self.safety_checker.ainvoke( + {"rephrased_input": rephrased_input}, config, ) - mapped_state = BaseState( - rephrased_input=result.get("rephrased_input", ""), - safety=result.get("safety", SAFETY_SAFE), - reason_unsafe=result.get("reason_unsafe", ""), + return BaseState( + rephrased_input=rephrased_input, + safety=safety_check["safety"].lower(), + reason_unsafe=safety_check["reason_unsafe"], ) - return BaseState(**state, **mapped_state) async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: search_results: list[WebSearchResult] = [] - if state.get("safety") == SAFETY_SAFE and config.get("configurable", {}).get( - "enable_postprocess", False + if ( + config["configurable"].get("enable_postprocess", False) + and state["safety"] == "true" ): result: SearchState = await self.search_workflow.ainvoke( SearchState( @@ -77,6 +77,5 @@ async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseSta ) search_results = result["search_results"] return BaseState( - **state, - additional_content=AdditionalContent(search_results=search_results), + additional_content=AdditionalContent(search_results=search_results) ) diff --git a/src/agent/profiles/cross_database.py b/src/agent/profiles/cross_database.py index 8a725f1..74ef26c 100644 --- a/src/agent/profiles/cross_database.py +++ b/src/agent/profiles/cross_database.py @@ -15,12 +15,16 @@ create_uniprot_rewriter_w_reactome from agent.tasks.cross_database.summarize_reactome_uniprot import \ create_reactome_uniprot_summarizer +from agent.tasks.detect_language import create_language_detector from agent.tasks.safety_checker import SafetyCheck, create_safety_checker from retrievers.reactome.rag import create_reactome_rag from retrievers.uniprot.rag import create_uniprot_rag class CrossDatabaseState(BaseState): + safety: str # LLM-assessed safety level of the user input + query_language: str # language of the user input + reactome_query: str # LLM-generated query for Reactome reactome_answer: str # LLM-generated answer from Reactome reactome_completeness: str # LLM-assessed completeness of the Reactome answer @@ -44,6 +48,7 @@ def __init__( self.safety_checker = create_safety_checker(llm) self.completeness_checker = create_completeness_grader(llm) + self.detect_language = create_language_detector(llm) self.write_reactome_query = create_reactome_rewriter_w_uniprot(llm) self.write_uniprot_query = create_uniprot_rewriter_w_reactome(llm) self.summarize_final_answer = create_reactome_uniprot_summarizer( @@ -55,6 +60,7 @@ def __init__( # Set up nodes state_graph.add_node("check_question_safety", self.check_question_safety) 
state_graph.add_node("preprocess_question", self.preprocess) + state_graph.add_node("identify_query_language", self.identify_query_language) state_graph.add_node("conduct_research", self.conduct_research) state_graph.add_node("generate_reactome_answer", self.generate_reactome_answer) state_graph.add_node("rewrite_reactome_query", self.rewrite_reactome_query) @@ -68,6 +74,7 @@ def __init__( state_graph.add_node("postprocess", self.postprocess) # Set up edges state_graph.set_entry_point("preprocess_question") + state_graph.add_edge("preprocess_question", "identify_query_language") state_graph.add_edge("preprocess_question", "check_question_safety") state_graph.add_conditional_edges( "check_question_safety", @@ -105,7 +112,7 @@ async def check_question_safety( config, ) if result.binary_score == "No": - inappropriate_input = f"This is the user's question and it is NOT appropriate for you to answer: {state["user_input"]}. \n\n explain that you are unable to answer the question but you can answer questions about topics related to the Reactome Pathway Knowledgebase or UniProt Knowledgebase." + inappropriate_input = f"This is the user's question and it is NOT appropriate for you to answer: {state["user_input"]}. \n\n explain that you are unable to answer the question but you can answer questions about topics related to the Reactome Pathway Knowledgebase or UniProt Knowledgebas." return CrossDatabaseState( safety=result.binary_score, user_input=inappropriate_input, @@ -123,6 +130,14 @@ async def proceed_with_research( else: return "Finish" + async def identify_query_language( + self, state: CrossDatabaseState, config: RunnableConfig + ) -> CrossDatabaseState: + query_language: str = await self.detect_language.ainvoke( + {"user_input": state["user_input"]}, config + ) + return CrossDatabaseState(query_language=query_language) + async def conduct_research( self, state: CrossDatabaseState, config: RunnableConfig ) -> CrossDatabaseState: @@ -241,6 +256,7 @@ async def generate_final_response( final_response: str = await self.summarize_final_answer.ainvoke( { "input": state["rephrased_input"], + "query_language": state["query_language"], "reactome_answer": state["reactome_answer"], "uniprot_answer": state["uniprot_answer"], }, diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py index 54b9bcb..bd182ec 100644 --- a/src/agent/profiles/react_to_me.py +++ b/src/agent/profiles/react_to_me.py @@ -6,8 +6,7 @@ from langchain_core.runnables import Runnable, RunnableConfig from langgraph.graph.state import StateGraph -from agent.profiles.base import (SAFETY_SAFE, SAFETY_UNSAFE, BaseGraphBuilder, - BaseState) +from agent.profiles.base import BaseGraphBuilder, BaseState from agent.tasks.unsafe_answer import create_unsafe_answer_generator from retrievers.reactome.rag import create_reactome_rag @@ -51,44 +50,26 @@ def __init__( self.uncompiled_graph: StateGraph = state_graph async def proceed_with_research( - self, state: ReactToMeState + self, state: BaseState ) -> Literal["Continue", "Finish"]: - return "Continue" if state.get("safety") == SAFETY_SAFE else "Finish" + return "Continue" if state["safety"] == "true" else "Finish" async def generate_unsafe_response( self, state: ReactToMeState, config: RunnableConfig ) -> ReactToMeState: - final_answer_message = await self.unsafe_answer_generator.ainvoke( + answer: str = await self.unsafe_answer_generator.ainvoke( { - "user_input": state.get("rephrased_input", state["user_input"]), - "reason_unsafe": state.get("reason_unsafe", ""), + 
"user_input": state["rephrased_input"], + "reason_unsafe": state["reason_unsafe"], }, config, ) - - final_answer = ( - final_answer_message.content - if hasattr(final_answer_message, "content") - else str(final_answer_message) - ) - - history = list(state.get("chat_history", [])) - history.extend( - [ - HumanMessage(state["user_input"]), - ( - final_answer_message - if hasattr(final_answer_message, "content") - else AIMessage(final_answer) - ), - ] - ) - return ReactToMeState( - chat_history=history, - answer=final_answer, - safety=SAFETY_UNSAFE, - additional_content={"search_results": []}, + chat_history=[ + HumanMessage(state["user_input"]), + AIMessage(answer), + ], + answer=answer, ) async def call_model( @@ -97,23 +78,17 @@ async def call_model( result: dict[str, Any] = await self.reactome_rag.ainvoke( { "input": state["rephrased_input"], - "chat_history": ( - state.get("chat_history") - if state.get("chat_history") - else [HumanMessage(state["user_input"])] + "chat_history": state.get( + "chat_history", [HumanMessage(state["user_input"])] ), }, config, ) - history = list(state.get("chat_history", [])) - history.extend( - [ + return ReactToMeState( + chat_history=[ HumanMessage(state["user_input"]), AIMessage(result["answer"]), - ] - ) - return ReactToMeState( - chat_history=history, + ], answer=result["answer"], ) diff --git a/src/agent/tasks/safety_checker.py b/src/agent/tasks/safety_checker.py index d5af300..b997731 100644 --- a/src/agent/tasks/safety_checker.py +++ b/src/agent/tasks/safety_checker.py @@ -3,27 +3,7 @@ from langchain_core.runnables import Runnable from pydantic import BaseModel, Field - -class SafetyCheck(BaseModel): - safety: str = Field( - description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'true' or 'false'." - ) - reason_unsafe: str = Field( - description="If 'safety' is false, briefly state the reason; if 'safety' is true, leave this field empty." - ) - - -def create_safety_checker(llm: BaseChatModel) -> Runnable: - """ - Create a safety checker chain. - - Args: - llm: Language model to use - - Returns: - Runnable that takes rephrased_input and returns SafetyCheck - """ - safety_check_message = """ +safety_check_message = """ You are an expert scientific assistant. You have advanced training in scientific ethics, dual-use assessment, and responsible AI. Tasks: @@ -49,11 +29,22 @@ def create_safety_checker(llm: BaseChatModel) -> Runnable: "reason_unsafe": "..." // If 'safety' is false, briefly state the reason """ - prompt = ChatPromptTemplate.from_messages( - [ - ("system", safety_check_message), - ("human", "User question: \n\n {rephrased_input}"), - ] +safety_check_prompt = ChatPromptTemplate.from_messages( + [ + ("system", safety_check_message), + ("human", "User question: \n\n {rephrased_input}"), + ] +) + + +class SafetyCheck(BaseModel): + safety: str = Field( + description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'true' or 'false'." + ) + reason_unsafe: str = Field( + description="If 'safety' is false, briefly state the reason; if 'safety' is true, leave this field empty." 
) - return prompt | llm.with_structured_output(SafetyCheck) + +def create_safety_checker(llm: BaseChatModel) -> Runnable: + return safety_check_prompt | llm.with_structured_output(SafetyCheck) diff --git a/src/agent/tasks/unsafe_answer.py b/src/agent/tasks/unsafe_answer.py index 7bfd8ec..42d4ae6 100644 --- a/src/agent/tasks/unsafe_answer.py +++ b/src/agent/tasks/unsafe_answer.py @@ -1,19 +1,9 @@ from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import Runnable - -def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable: - """ - Create an unsafe answer generator chain. - - Args: - llm: Language model to use. - - Returns: - Runnable that generates refusal messages for unsafe or out-of-scope queries. - """ - system_prompt = """ +unsafe_answer_message = """ You are an expert scientific assistant operating under the React-to-Me platform. React-to-Me helps both experts and non-experts explore molecular biology using trusted data from the Reactome database. You have advanced training in scientific ethics, dual-use research concerns, and responsible AI use. @@ -31,16 +21,20 @@ def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable: You must not provide any workaround, implicit answer, or redirection toward unsafe content. """ - streaming_llm = llm.model_copy(update={"streaming": True}) - prompt = ChatPromptTemplate.from_messages( - [ - ("system", system_prompt), - ( - "user", - "Question:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}", - ), - ] - ) +unsafe_answer_prompt = ChatPromptTemplate.from_messages( + [ + ("system", unsafe_answer_message), + ( + "user", + "Question:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}", + ), + ] +) + - return prompt | streaming_llm +def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable: + streaming_llm = llm.model_copy(update={"streaming": True}) + return (unsafe_answer_prompt | streaming_llm | StrOutputParser()).with_config( + run_name="unsafe_answer" + ) diff --git a/src/tools/preprocessing/__init__.py b/src/tools/preprocessing/__init__.py deleted file mode 100644 index 01a26cf..0000000 --- a/src/tools/preprocessing/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -Preprocessing utilities with reusable workflows and state definitions. 
-""" - -from tools.preprocessing.state import PreprocessingState -from tools.preprocessing.workflow import create_preprocessing_workflow - -__all__ = ["PreprocessingState", "create_preprocessing_workflow"] diff --git a/src/tools/preprocessing/state.py b/src/tools/preprocessing/state.py deleted file mode 100644 index 0e9910d..0000000 --- a/src/tools/preprocessing/state.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import TypedDict - -from langchain_core.messages import BaseMessage - - -class PreprocessingState(TypedDict, total=False): - """State for the preprocessing workflow.""" - - # Inputs - user_input: str - chat_history: list[BaseMessage] - - # Task outputs - rephrased_input: str - safety: str - reason_unsafe: str diff --git a/src/tools/preprocessing/workflow.py b/src/tools/preprocessing/workflow.py deleted file mode 100644 index fd5c0d0..0000000 --- a/src/tools/preprocessing/workflow.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import Any, Callable - -from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.runnables import Runnable, RunnableConfig -from langgraph.graph import StateGraph -from langgraph.graph.state import CompiledStateGraph -from langgraph.utils.runnable import RunnableLike - -from agent.tasks.rephrase import create_rephrase_chain -from agent.tasks.safety_checker import create_safety_checker -from tools.preprocessing.state import PreprocessingState - - -def create_task_wrapper( - task: Runnable, - input_mapper: Callable[[PreprocessingState], dict[str, Any]], - output_mapper: Callable[[Any], PreprocessingState], -) -> RunnableLike: - """Wrap a runnable with state mappers.""" - - async def _wrapper( - state: PreprocessingState, config: RunnableConfig - ) -> PreprocessingState: - result = await task.ainvoke(input_mapper(state), config) - return output_mapper(result) - - return _wrapper - - -def create_preprocessing_workflow(llm: BaseChatModel) -> CompiledStateGraph: - """Create the preprocessing workflow with rephrasing and safety checking.""" - - tasks = { - "rephrase_query": ( - create_rephrase_chain(llm), - lambda state: { - "user_input": state["user_input"], - "chat_history": state.get("chat_history", []), - }, - lambda result: PreprocessingState(rephrased_input=result), - ), - "safety_check": ( - create_safety_checker(llm), - lambda state: {"rephrased_input": state["rephrased_input"]}, - lambda result: PreprocessingState( - safety=result.safety.lower(), - reason_unsafe=result.reason_unsafe, - ), - ), - } - - workflow = StateGraph(PreprocessingState) - - for node_name, (task, input_mapper, output_mapper) in tasks.items(): - workflow.add_node( - node_name, create_task_wrapper(task, input_mapper, output_mapper) - ) - - workflow.set_entry_point("rephrase_query") - workflow.add_edge("rephrase_query", "safety_check") - workflow.set_finish_point("safety_check") - - return workflow.compile() From 5808ae6ecedf8d52e64c2e3fb1663bd3e4afc4a5 Mon Sep 17 00:00:00 2001 From: Greg Hogue Date: Fri, 16 Jan 2026 16:13:54 -0500 Subject: [PATCH 11/12] dedup safety routing --- src/agent/profiles/base.py | 5 ++++- src/agent/profiles/cross_database.py | 9 --------- src/agent/profiles/react_to_me.py | 7 +------ 3 files changed, 5 insertions(+), 16 deletions(-) diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py index f594435..f62b6c6 100644 --- a/src/agent/profiles/base.py +++ b/src/agent/profiles/base.py @@ -1,4 +1,4 @@ -from typing import Annotated, TypedDict +from typing import Annotated, Literal, TypedDict from 
langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel @@ -62,6 +62,9 @@ async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseStat reason_unsafe=safety_check["reason_unsafe"], ) + def proceed_with_research(self, state: BaseState) -> Literal["Continue", "Finish"]: + return "Continue" if state["safety"] == "true" else "Finish" + async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: search_results: list[WebSearchResult] = [] if ( diff --git a/src/agent/profiles/cross_database.py b/src/agent/profiles/cross_database.py index 74ef26c..f96b531 100644 --- a/src/agent/profiles/cross_database.py +++ b/src/agent/profiles/cross_database.py @@ -22,7 +22,6 @@ class CrossDatabaseState(BaseState): - safety: str # LLM-assessed safety level of the user input query_language: str # language of the user input reactome_query: str # LLM-generated query for Reactome @@ -122,14 +121,6 @@ async def check_question_safety( else: return CrossDatabaseState(safety=result.binary_score) - async def proceed_with_research( - self, state: CrossDatabaseState - ) -> Literal["Continue", "Finish"]: - if state["safety"] == "Yes": - return "Continue" - else: - return "Finish" - async def identify_query_language( self, state: CrossDatabaseState, config: RunnableConfig ) -> CrossDatabaseState: diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py index bd182ec..76364e4 100644 --- a/src/agent/profiles/react_to_me.py +++ b/src/agent/profiles/react_to_me.py @@ -1,4 +1,4 @@ -from typing import Any, Literal +from typing import Any from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel @@ -49,11 +49,6 @@ def __init__( self.uncompiled_graph: StateGraph = state_graph - async def proceed_with_research( - self, state: BaseState - ) -> Literal["Continue", "Finish"]: - return "Continue" if state["safety"] == "true" else "Finish" - async def generate_unsafe_response( self, state: ReactToMeState, config: RunnableConfig ) -> ReactToMeState: From f8f3020dfc8e8fdd9422e9dd0067a91d873355f8 Mon Sep 17 00:00:00 2001 From: Greg Hogue Date: Fri, 16 Jan 2026 17:03:21 -0500 Subject: [PATCH 12/12] update macos x64 image for GitHub Actions --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6908dab..c806cff 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,7 +35,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-13] + os: [ubuntu-latest, macos-15-intel] steps: - uses: actions/checkout@v4
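
Taken together, patches 01 through 11 converge on a single routing pattern in the ReactToMe graph: preprocessing rephrases the question and records a safety verdict, a conditional edge then sends safe questions to the RAG model and unsafe ones to a refusal generator, and both branches meet again in postprocessing. Below is a minimal, stubbed sketch of that pattern, not code taken from the patches themselves: it assumes only that langgraph is installed, the node bodies are placeholder functions (the keyword check standing in for the LLM safety chain is purely illustrative), and the node and state names simply mirror the ones used above.

# Hypothetical, stubbed illustration of the routing introduced in this series;
# the real node implementations live in BaseGraphBuilder / ReactToMeGraphBuilder.
from typing import Literal, TypedDict

from langgraph.graph import StateGraph


class State(TypedDict, total=False):
    user_input: str
    rephrased_input: str
    safety: str          # "true" or "false", as produced by the safety checker
    reason_unsafe: str
    answer: str


def preprocess(state: State) -> State:
    # Stand-in for the rephrase + safety-check chains; a real run calls the LLM.
    unsafe = "pathogen" in state["user_input"].lower()  # illustrative heuristic only
    return {
        "rephrased_input": state["user_input"],
        "safety": "false" if unsafe else "true",
        "reason_unsafe": "dual-use request" if unsafe else "",
    }


def call_model(state: State) -> State:
    # Stand-in for the Reactome RAG answer.
    return {"answer": f"[RAG answer for: {state['rephrased_input']}]"}


def generate_unsafe_response(state: State) -> State:
    # Stand-in for the unsafe-answer generator (polite refusal).
    return {"answer": f"Cannot answer: {state['reason_unsafe']}"}


def postprocess(state: State) -> State:
    # Web search would only run here when safety == "true".
    return {}


def proceed_with_research(state: State) -> Literal["Continue", "Finish"]:
    return "Continue" if state["safety"] == "true" else "Finish"


graph = StateGraph(State)
graph.add_node("preprocess", preprocess)
graph.add_node("model", call_model)
graph.add_node("generate_unsafe_response", generate_unsafe_response)
graph.add_node("postprocess", postprocess)
graph.set_entry_point("preprocess")
graph.add_conditional_edges(
    "preprocess",
    proceed_with_research,
    {"Continue": "model", "Finish": "generate_unsafe_response"},
)
graph.add_edge("model", "postprocess")
graph.add_edge("generate_unsafe_response", "postprocess")
graph.set_finish_point("postprocess")

app = graph.compile()
print(app.invoke({"user_input": "What does the TP53 pathway do?"})["answer"])

Because both branches rejoin at the postprocess node, the web-search step can be gated on the safety flag in a single place, which is the behaviour the later patches in the series settle on.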