44 changes: 30 additions & 14 deletions src/agent/profiles/base.py
@@ -1,14 +1,19 @@
from typing import Annotated, TypedDict
from typing import Annotated, Literal, TypedDict

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph.message import add_messages

from agent.tasks.rephrase import create_rephrase_chain
from tools.external_search.state import SearchState, WebSearchResult
from tools.external_search.workflow import create_search_workflow
from tools.preprocessing.state import PreprocessingState
from tools.preprocessing.workflow import create_preprocessing_workflow

SAFETY_SAFE: Literal["true"] = "true"
SAFETY_UNSAFE: Literal["false"] = "false"
DEFAULT_LANGUAGE: str = "English"


class AdditionalContent(TypedDict, total=False):
@@ -20,39 +25,49 @@ class InputState(TypedDict, total=False):


class OutputState(TypedDict, total=False):
answer: str # primary LLM response that is streamed to the user
answer: str # LLM response streamed to the user
additional_content: AdditionalContent # sends on graph completion


class BaseState(InputState, OutputState, total=False):
rephrased_input: str # LLM-generated query from user input
rephrased_input: (
str # contextualized, LLM-generated standalone query from user input
)
chat_history: Annotated[list[BaseMessage], add_messages]
safety: str # LLM-assessed safety level of the user input
reason_unsafe: str


class BaseGraphBuilder:
# NOTE: Anything that is common to all graph builders goes here

def __init__(
self,
llm: BaseChatModel,
embedding: Embeddings,
) -> None:
self.rephrase_chain: Runnable = create_rephrase_chain(llm)
self.preprocessing_workflow: Runnable = create_preprocessing_workflow(llm)
self.search_workflow: Runnable = create_search_workflow(llm)

async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
rephrased_input: str = await self.rephrase_chain.ainvoke(
{
"user_input": state["user_input"],
"chat_history": state["chat_history"],
},
result: PreprocessingState = await self.preprocessing_workflow.ainvoke(
PreprocessingState(
user_input=state["user_input"],
chat_history=state.get("chat_history", []),
),
config,
)
return BaseState(rephrased_input=rephrased_input)
mapped_state = BaseState(
rephrased_input=result.get("rephrased_input", ""),
safety=result.get("safety", SAFETY_SAFE),
reason_unsafe=result.get("reason_unsafe", ""),
)
return BaseState(**state, **mapped_state)

async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
search_results: list[WebSearchResult] = []
if config["configurable"]["enable_postprocess"]:
if state.get("safety") == SAFETY_SAFE and config.get("configurable", {}).get(
"enable_postprocess", False
):
result: SearchState = await self.search_workflow.ainvoke(
SearchState(
input=state["rephrased_input"],
@@ -62,5 +77,6 @@ async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
)
search_results = result["search_results"]
return BaseState(
additional_content=AdditionalContent(search_results=search_results)
**state,
additional_content=AdditionalContent(search_results=search_results),
Collaborator:

Does adding **state in spots like this solve some issue with updating the state?

Collaborator Author:

The BaseState(**state, …) merge isn’t there to work around a bug—it’s there because LangGraph hands us an evolving state dict, and we don’t want to drop any of the fields that preprocessing or earlier nodes already wrote (rephrased input, safety tag, chat history, etc.). Postprocess only needs to add additional_content, so merging **state with the new field preserves the existing state while layering on the search results. If we just returned BaseState(additional_content=...), we’d lose everything else that was already in the state and downstream nodes would break.

Collaborator:

LangGraph implicitly updates the state using the returned dict without dropping omitted fields, so we shouldn't need to do this.
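
To make the point being debated concrete, here is a minimal, hypothetical sketch (not part of this PR; the state schema and node names are invented) of LangGraph's partial-update behavior: a node returns only the keys it changed, and the graph merges them into the existing state rather than replacing it.

# Hypothetical, self-contained example (the "Demo" schema and "set_answer" node
# are illustrative, not from this PR). It shows that a LangGraph node may return
# a partial state dict and that omitted keys are preserved after the update.
from typing import TypedDict

from langgraph.graph.state import StateGraph


class Demo(TypedDict, total=False):
    rephrased_input: str
    answer: str


def set_answer(state: Demo) -> Demo:
    # Return only the key this node produces; "rephrased_input" is not dropped.
    return Demo(answer="hello")


builder = StateGraph(Demo)
builder.add_node("set_answer", set_answer)
builder.set_entry_point("set_answer")
builder.set_finish_point("set_answer")
graph = builder.compile()

print(graph.invoke(Demo(rephrased_input="hi")))
# Expected output: {'rephrased_input': 'hi', 'answer': 'hello'}

If that behavior holds for these nodes, the **state spread is redundant but harmless; the sketch is only meant to make the disagreement concrete.
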

)
18 changes: 1 addition & 17 deletions src/agent/profiles/cross_database.py
@@ -15,16 +15,12 @@
create_uniprot_rewriter_w_reactome
from agent.tasks.cross_database.summarize_reactome_uniprot import \
create_reactome_uniprot_summarizer
from agent.tasks.detect_language import create_language_detector
from agent.tasks.safety_checker import SafetyCheck, create_safety_checker
from retrievers.reactome.rag import create_reactome_rag
from retrievers.uniprot.rag import create_uniprot_rag


class CrossDatabaseState(BaseState):
safety: str # LLM-assessed safety level of the user input
query_language: str # language of the user input

reactome_query: str # LLM-generated query for Reactome
reactome_answer: str # LLM-generated answer from Reactome
reactome_completeness: str # LLM-assessed completeness of the Reactome answer
@@ -48,7 +44,6 @@ def __init__(

self.safety_checker = create_safety_checker(llm)
self.completeness_checker = create_completeness_grader(llm)
self.detect_language = create_language_detector(llm)
self.write_reactome_query = create_reactome_rewriter_w_uniprot(llm)
self.write_uniprot_query = create_uniprot_rewriter_w_reactome(llm)
self.summarize_final_answer = create_reactome_uniprot_summarizer(
@@ -60,7 +55,6 @@ def __init__(
# Set up nodes
state_graph.add_node("check_question_safety", self.check_question_safety)
state_graph.add_node("preprocess_question", self.preprocess)
state_graph.add_node("identify_query_language", self.identify_query_language)
state_graph.add_node("conduct_research", self.conduct_research)
state_graph.add_node("generate_reactome_answer", self.generate_reactome_answer)
state_graph.add_node("rewrite_reactome_query", self.rewrite_reactome_query)
@@ -74,7 +68,6 @@ def __init__(
state_graph.add_node("postprocess", self.postprocess)
# Set up edges
state_graph.set_entry_point("preprocess_question")
state_graph.add_edge("preprocess_question", "identify_query_language")
state_graph.add_edge("preprocess_question", "check_question_safety")
state_graph.add_conditional_edges(
"check_question_safety",
@@ -112,7 +105,7 @@ async def check_question_safety(
config,
)
if result.binary_score == "No":
inappropriate_input = f"This is the user's question and it is NOT appropriate for you to answer: {state["user_input"]}. \n\n explain that you are unable to answer the question but you can answer questions about topics related to the Reactome Pathway Knowledgebase or UniProt Knowledgebas."
inappropriate_input = f"This is the user's question and it is NOT appropriate for you to answer: {state["user_input"]}. \n\n explain that you are unable to answer the question but you can answer questions about topics related to the Reactome Pathway Knowledgebase or UniProt Knowledgebase."
return CrossDatabaseState(
safety=result.binary_score,
user_input=inappropriate_input,
@@ -130,14 +123,6 @@ async def proceed_with_research(
else:
return "Finish"

async def identify_query_language(
self, state: CrossDatabaseState, config: RunnableConfig
) -> CrossDatabaseState:
query_language: str = await self.detect_language.ainvoke(
{"user_input": state["user_input"]}, config
)
return CrossDatabaseState(query_language=query_language)

async def conduct_research(
self, state: CrossDatabaseState, config: RunnableConfig
) -> CrossDatabaseState:
@@ -256,7 +241,6 @@ async def generate_final_response(
final_response: str = await self.summarize_final_answer.ainvoke(
{
"input": state["rephrased_input"],
"query_language": state["query_language"],
"reactome_answer": state["reactome_answer"],
"uniprot_answer": state["uniprot_answer"],
},
70 changes: 62 additions & 8 deletions src/agent/profiles/react_to_me.py
@@ -1,12 +1,14 @@
from typing import Any
from typing import Any, Literal

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph.state import StateGraph

from agent.profiles.base import BaseGraphBuilder, BaseState
from agent.profiles.base import (SAFETY_SAFE, SAFETY_UNSAFE, BaseGraphBuilder,
BaseState)
from agent.tasks.unsafe_answer import create_unsafe_answer_generator
from retrievers.reactome.rag import create_reactome_rag


@@ -23,6 +25,7 @@ def __init__(
super().__init__(llm, embedding)

# Create runnables (tasks & tools)
self.unsafe_answer_generator = create_unsafe_answer_generator(llm)
self.reactome_rag: Runnable = create_reactome_rag(
llm, embedding, streaming=True
)
@@ -32,34 +35,85 @@ def __init__(
# Set up nodes
state_graph.add_node("preprocess", self.preprocess)
state_graph.add_node("model", self.call_model)
state_graph.add_node("generate_unsafe_response", self.generate_unsafe_response)
state_graph.add_node("postprocess", self.postprocess)
# Set up edges
state_graph.set_entry_point("preprocess")
state_graph.add_edge("preprocess", "model")
state_graph.add_conditional_edges(
"preprocess",
self.proceed_with_research,
{"Continue": "model", "Finish": "generate_unsafe_response"},
)
state_graph.add_edge("model", "postprocess")
state_graph.add_edge("generate_unsafe_response", "postprocess")
state_graph.set_finish_point("postprocess")

self.uncompiled_graph: StateGraph = state_graph

async def proceed_with_research(
self, state: ReactToMeState
) -> Literal["Continue", "Finish"]:
return "Continue" if state.get("safety") == SAFETY_SAFE else "Finish"

async def generate_unsafe_response(
self, state: ReactToMeState, config: RunnableConfig
) -> ReactToMeState:
final_answer_message = await self.unsafe_answer_generator.ainvoke(
{
"user_input": state.get("rephrased_input", state["user_input"]),
"reason_unsafe": state.get("reason_unsafe", ""),
},
config,
)

final_answer = (
final_answer_message.content
if hasattr(final_answer_message, "content")
else str(final_answer_message)
)

history = list(state.get("chat_history", []))
history.extend(
[
HumanMessage(state["user_input"]),
(
final_answer_message
if hasattr(final_answer_message, "content")
else AIMessage(final_answer)
),
]
)

return ReactToMeState(
chat_history=history,
answer=final_answer,
safety=SAFETY_UNSAFE,
additional_content={"search_results": []},
)

async def call_model(
self, state: ReactToMeState, config: RunnableConfig
) -> ReactToMeState:
result: dict[str, Any] = await self.reactome_rag.ainvoke(
{
"input": state["rephrased_input"],
"chat_history": (
state["chat_history"]
if state["chat_history"]
state.get("chat_history")
if state.get("chat_history")
else [HumanMessage(state["user_input"])]
),
},
config,
)
return ReactToMeState(
chat_history=[
history = list(state.get("chat_history", []))
history.extend(
[
HumanMessage(state["user_input"]),
AIMessage(result["answer"]),
],
]
)
return ReactToMeState(
chat_history=history,
answer=result["answer"],
)

67 changes: 48 additions & 19 deletions src/agent/tasks/safety_checker.py
@@ -3,28 +3,57 @@
from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field

safety_check_message = """
You are an expert grader with a strong understanding of ethical guidelines and responsible AI usage.

Your task is to determine whether a user-submitted question is appropriate and not harmful.
Provide a binary output as either:
Yes: The question is appropriate and not harmful.
No: The question is harmful, offensive, inappropriate, or unethical.
"""

safety_check_prompt = ChatPromptTemplate.from_messages(
[
("system", safety_check_message),
("human", "User question: \n\n {input}"),
]
)


class SafetyCheck(BaseModel):
binary_score: str = Field(
description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'Yes' or 'No'."
safety: str = Field(
description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'true' or 'false'."
)
reason_unsafe: str = Field(
description="If 'safety' is false, briefly state the reason; if 'safety' is true, leave this field empty."
)


def create_safety_checker(llm: BaseChatModel) -> Runnable:
return safety_check_prompt | llm.with_structured_output(SafetyCheck)
"""
Create a safety checker chain.

Args:
llm: Language model to use

Returns:
Runnable that takes rephrased_input and returns SafetyCheck
"""
safety_check_message = """
You are an expert scientific assistant. You have advanced training in scientific ethics, dual-use assessment, and responsible AI.

Tasks:

1. Safety & Appropriateness Check
- Assess the question and determine if it is safe and appropriate for responsible AI use.
- The question and any plausible answer to it must NOT be harmful, offensive, unethical, or violate scientific or ethical guidelines.
- The question or its likely answers must NOT request, describe, or explore how to carry out harmful, illegal, dangerous, or dual-use activities (such as modifying pathogens, increasing infectivity, or other forms of misuse)—even if the scenario is hypothetical, fictional, or indirect.
- The question must NOT request personal, medical, legal, or other types of advice outside the scope of scientific knowledge.
- If a question is ambiguous or could plausibly be used for harm, treat it as unsafe.
- Treat hypothetical, fictional, or made-up scenarios with the same level of scrutiny as real-world questions.

2. Reactome Relevance Check
- Determine if the question is relevant to biology, life sciences, molecular biology, or related topics.
- Mark questions as not relevant if they are about unrelated topics (such as programming, math, history, trivia, etc.).

IMPORTANT:
- If the standalone question is unsafe or not relevant, return "safety": "false".
- If the standalone question is both safe **and** relevant, return "safety": "true".

Return only a JSON object in the following format:
"safety": "true" or "false", // Use string, not boolean. Mark as "false" if unsafe OR not relevant.
"reason_unsafe": "..." // If 'safety' is false, briefly state the reason
"""

prompt = ChatPromptTemplate.from_messages(
[
("system", safety_check_message),
("human", "User question: \n\n {rephrased_input}"),
]
)

return prompt | llm.with_structured_output(SafetyCheck)