80 changes: 59 additions & 21 deletions src/agent/profiles/base.py
Collaborator:

Please exclude changes to code formatting & comments to unmodified existing code from the diff.

@@ -1,58 +1,93 @@
from typing import Annotated, TypedDict
from typing import Annotated, Literal, TypedDict

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph.message import add_messages

from agent.tasks.rephrase import create_rephrase_chain
from tools.external_search.state import SearchState, WebSearchResult
from tools.external_search.workflow import create_search_workflow
from tools.preprocessing.state import PreprocessingState
from tools.preprocessing.workflow import create_preprocessing_workflow

# Constants
SAFETY_SAFE: Literal["true"] = "true"
SAFETY_UNSAFE: Literal["false"] = "false"
DEFAULT_LANGUAGE: str = "English"


class AdditionalContent(TypedDict, total=False):
"""Additional content sent on graph completion."""

search_results: list[WebSearchResult]


class InputState(TypedDict, total=False):
user_input: str # User input text
"""Input state for user queries."""

user_input: str


class OutputState(TypedDict, total=False):
answer: str # primary LLM response that is streamed to the user
additional_content: AdditionalContent # sends on graph completion
"""Output state for responses."""

answer: str
additional_content: AdditionalContent


class BaseState(InputState, OutputState, total=False):
rephrased_input: str # LLM-generated query from user input
"""Base state containing all common fields for agent workflows."""

rephrased_input: str
chat_history: Annotated[list[BaseMessage], add_messages]

# Preprocessing results
safety: str
reason_unsafe: str
expanded_queries: list[str]
detected_language: str


class BaseGraphBuilder:
# NOTE: Anything that is common to all graph builders goes here

def __init__(
self,
llm: BaseChatModel,
embedding: Embeddings,
) -> None:
self.rephrase_chain: Runnable = create_rephrase_chain(llm)
"""Base class for all graph builders with common preprocessing and postprocessing."""

def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None:
"""Initialize with LLM and embedding models."""
self.preprocessing_workflow: Runnable = create_preprocessing_workflow(llm)
self.search_workflow: Runnable = create_search_workflow(llm)

async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
rephrased_input: str = await self.rephrase_chain.ainvoke(
{
"user_input": state["user_input"],
"chat_history": state["chat_history"],
},
"""Run the complete preprocessing workflow and map results to state."""
result: PreprocessingState = await self.preprocessing_workflow.ainvoke(
PreprocessingState(
user_input=state["user_input"],
chat_history=state["chat_history"],
),
config,
)
return BaseState(rephrased_input=rephrased_input)

return self._map_preprocessing_result(result)

def _map_preprocessing_result(self, result: PreprocessingState) -> BaseState:
"""Map preprocessing results to BaseState with defaults."""
return BaseState(
rephrased_input=result["rephrased_input"],
safety=result.get("safety", SAFETY_SAFE),
reason_unsafe=result.get("reason_unsafe", ""),
expanded_queries=result.get("expanded_queries", []),
detected_language=result.get("detected_language", DEFAULT_LANGUAGE),
)
Collaborator:

remove _map_preprocessing_result() and just return this in preprocess()
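A minimal sketch of that inlining, reusing the exact fields and defaults from the diff above:

```python
# Sketch of the reviewer's suggestion: preprocess() returns the mapped state
# directly, so the private _map_preprocessing_result() helper can be dropped.
async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
    """Run the complete preprocessing workflow and map results to state."""
    result: PreprocessingState = await self.preprocessing_workflow.ainvoke(
        PreprocessingState(
            user_input=state["user_input"],
            chat_history=state["chat_history"],
        ),
        config,
    )
    return BaseState(
        rephrased_input=result["rephrased_input"],
        safety=result.get("safety", SAFETY_SAFE),
        reason_unsafe=result.get("reason_unsafe", ""),
        expanded_queries=result.get("expanded_queries", []),
        detected_language=result.get("detected_language", DEFAULT_LANGUAGE),
    )
```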


async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
"""Postprocess that preserves existing state and conditionally adds search results."""
search_results: list[WebSearchResult] = []
if config["configurable"]["enable_postprocess"]:

# Only run external search for safe questions
if (
state.get("safety") == SAFETY_SAFE
and config["configurable"]["enable_postprocess"]
):
result: SearchState = await self.search_workflow.ainvoke(
SearchState(
input=state["rephrased_input"],
@@ -61,6 +96,9 @@ async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
config=RunnableConfig(callbacks=config["callbacks"]),
)
search_results = result["search_results"]

# Create new state with updated additional_content
return BaseState(
**state, # Copy existing state
additional_content=AdditionalContent(search_results=search_results)
)
106 changes: 88 additions & 18 deletions src/agent/profiles/react_to_me.py
Collaborator:

Please exclude changes to code formatting & comments to unmodified existing code from the diff.

@@ -1,52 +1,123 @@
from typing import Any
from typing import Any, Literal

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.graph.state import StateGraph

from agent.profiles.base import BaseGraphBuilder, BaseState
from agent.profiles.base import (SAFETY_SAFE, SAFETY_UNSAFE, BaseGraphBuilder,
BaseState)
from agent.tasks.final_answer_generation.unsafe_question import \
create_unsafe_answer_generator
from retrievers.reactome.rag import create_reactome_rag


class ReactToMeState(BaseState):
"""ReactToMe state extends BaseState with all preprocessing results."""

pass


class ReactToMeGraphBuilder(BaseGraphBuilder):
def __init__(
self,
llm: BaseChatModel,
embedding: Embeddings,
) -> None:
"""Graph builder for ReactToMe profile with Reactome-specific functionality."""

def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None:
"""Initialize ReactToMe graph builder with required components."""
super().__init__(llm, embedding)

# Create runnables (tasks & tools)
# Create a streaming LLM instance only for final answer generation
streaming_llm = ChatOpenAI(
model=llm.model_name if hasattr(llm, "model_name") else "gpt-4o-mini",
temperature=0.0,
streaming=True,
)

self.unsafe_answer_generator = create_unsafe_answer_generator(streaming_llm)
self.reactome_rag: Runnable = create_reactome_rag(
llm, embedding, streaming=True
streaming_llm, embedding, streaming=True
Collaborator:

Why create streaming_llm like this when we already have this?:

llm = llm.model_copy(update={"streaming": True})
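Assuming llm is a pydantic-v2 model with a streaming field (true for ChatOpenAI in current langchain-core), the __init__ could shrink to roughly this sketch:

```python
# Sketch of the reviewer's suggestion: derive the streaming instance from the
# injected LLM instead of hand-constructing a new ChatOpenAI.
def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None:
    """Initialize ReactToMe graph builder with required components."""
    super().__init__(llm, embedding)
    streaming_llm = llm.model_copy(update={"streaming": True})
    self.unsafe_answer_generator = create_unsafe_answer_generator(streaming_llm)
    self.reactome_rag: Runnable = create_reactome_rag(
        streaming_llm, embedding, streaming=True
    )
    self.uncompiled_graph: StateGraph = self._build_workflow()
```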

)

# Create graph
self.uncompiled_graph: StateGraph = self._build_workflow()

def _build_workflow(self) -> StateGraph:
"""Build and configure the ReactToMe workflow graph."""
state_graph = StateGraph(ReactToMeState)
# Set up nodes

# Add nodes
state_graph.add_node("preprocess", self.preprocess)
state_graph.add_node("model", self.call_model)
state_graph.add_node("generate_unsafe_response", self.generate_unsafe_response)
state_graph.add_node("postprocess", self.postprocess)
# Set up edges

# Add edges
state_graph.set_entry_point("preprocess")
state_graph.add_edge("preprocess", "model")
state_graph.add_conditional_edges(
"preprocess",
self.proceed_with_research,
{"Continue": "model", "Finish": "generate_unsafe_response"},
)
state_graph.add_edge("model", "postprocess")
state_graph.add_edge("generate_unsafe_response", "postprocess")
state_graph.set_finish_point("postprocess")

self.uncompiled_graph: StateGraph = state_graph
return state_graph

async def preprocess(
self, state: ReactToMeState, config: RunnableConfig
) -> ReactToMeState:
"""Run preprocessing workflow."""
result = await super().preprocess(state, config)
return ReactToMeState(**result)
Collaborator:

No need to define this preprocess() if it's just going to run the one from the superclass (BaseGraphBuilder)


async def proceed_with_research(
self, state: ReactToMeState
) -> Literal["Continue", "Finish"]:
"""Determine whether to proceed with research based on safety check."""
return "Continue" if state["safety"] == SAFETY_SAFE else "Finish"

async def generate_unsafe_response(
self, state: ReactToMeState, config: RunnableConfig
) -> ReactToMeState:
"""Generate appropriate refusal response for unsafe queries."""
final_answer_message = await self.unsafe_answer_generator.ainvoke(
{
"language": state["detected_language"],
"user_input": state["rephrased_input"],
"reason_unsafe": state["reason_unsafe"],
},
config,
)

final_answer = (
final_answer_message.content
if hasattr(final_answer_message, "content")
else str(final_answer_message)
)

return ReactToMeState(
chat_history=[
HumanMessage(state["user_input"]),
(
final_answer_message
if hasattr(final_answer_message, "content")
else AIMessage(final_answer)
),
],
answer=final_answer,
safety=SAFETY_UNSAFE,
additional_content={"search_results": []},
)

async def call_model(
self, state: ReactToMeState, config: RunnableConfig
) -> ReactToMeState:
"""Generate response using Reactome RAG for safe queries."""
result: dict[str, Any] = await self.reactome_rag.ainvoke(
{
"input": state["rephrased_input"],
"expanded_queries": state.get("expanded_queries", []),
"chat_history": (
state["chat_history"]
if state["chat_history"]
@@ -55,6 +126,7 @@ async def call_model
},
config,
)

return ReactToMeState(
chat_history=[
HumanMessage(state["user_input"]),
@@ -64,8 +136,6 @@ async def call_model(
)


def create_reactome_graph(
llm: BaseChatModel,
embedding: Embeddings,
) -> StateGraph:
def create_reactome_graph(llm: BaseChatModel, embedding: Embeddings) -> StateGraph:
"""Create and return the ReactToMe workflow graph."""
return ReactToMeGraphBuilder(llm, embedding).uncompiled_graph
47 changes: 47 additions & 0 deletions src/agent/tasks/final_answer_generation/unsafe_question.py
@@ -0,0 +1,47 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable


# Unsafe or out-of-scope answer generator
def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable:
"""
Create an unsafe answer generator chain.
Args:
llm: Language model to use
Returns:
Runnable that takes language, user_input, and reason_unsafe
"""
system_prompt = """
You are an expert scientific assistant operating under the React-to-Me platform. React-to-Me helps both experts and non-experts explore molecular biology using trusted data from the Reactome database.
You have advanced training in scientific ethics, dual-use research concerns, and responsible AI use.
You will receive three inputs:
1. The user's question.
2. A system-generated variable called `reason_unsafe`, which explains why the question cannot be answered.
3. The user's preferred language (as a language code or name).
Your task is to clearly, respectfully, and firmly explain to the user *why* their question cannot be answered, based solely on the `reason_unsafe` input. Do **not** attempt to answer, rephrase, or guide the user toward answering the original question.
You must:
- Respond in the user’s preferred language.
- Politely explain the refusal, grounded in the `reason_unsafe`.
- Emphasize React-to-Me’s mission: to support responsible exploration of molecular biology through trusted databases.
- Suggest examples of appropriate topics (e.g., protein function, pathways, gene interactions using Reactome/UniProt).
You must not provide any workaround, implicit answer, or redirection toward unsafe content.
"""
prompt = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
(
"user",
"Language:{language}\n\nQuestion:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}",
),
]
)

return prompt | llm
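A hypothetical invocation of this chain (the question and reason shown are made up for illustration; llm is any BaseChatModel):

```python
# Hypothetical usage sketch for create_unsafe_answer_generator.
chain = create_unsafe_answer_generator(llm)
message = await chain.ainvoke(
    {
        "language": "English",
        "user_input": "How can I make a virus more transmissible?",
        "reason_unsafe": "Dual-use research concern: pathogen enhancement.",
    }
)
print(message.content)  # polite, firm refusal grounded in reason_unsafe
```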
60 changes: 60 additions & 0 deletions src/agent/tasks/query_expansion.py
@@ -0,0 +1,60 @@
import json
from typing import List
Collaborator:

Importing the List type from typing is deprecated. Current Python just uses the built-in list directly.
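For reference, with built-in generics (PEP 585, Python 3.9+) the typing import disappears:

```python
# No "from typing import List" needed; list has been generic since Python 3.9.
def QueryExpansionParser(output: str) -> list[str]:
    ...
```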


from langchain.prompts import ChatPromptTemplate
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableLambda


def QueryExpansionParser(output: str) -> List[str]:
"""Parse JSON array output from LLM."""
try:
return json.loads(output)
except json.JSONDecodeError:
raise ValueError("LLM output was not valid JSON. Output:\n" + output)
Collaborator:

Will an LLM emitting invalid JSON crash the chatbot here?
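One defensive option (not part of this PR) is to degrade gracefully instead of raising, so a malformed LLM response yields no expansions rather than an exception propagating out of the chain:

```python
# Sketch of a non-raising variant: malformed output falls back to [] and
# retrieval proceeds with the original query alone.
def QueryExpansionParser(output: str) -> list[str]:
    """Parse a JSON array of queries; return [] on malformed output."""
    try:
        parsed = json.loads(output)
    except json.JSONDecodeError:
        return []
    return parsed if isinstance(parsed, list) else []
```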



def create_query_expander(llm: BaseChatModel) -> Runnable:
"""
Create a query expansion chain that generates 4 alternative queries.
Args:
llm: Language model to use
Returns:
Runnable that takes rephrased_input and returns list[str]
"""
system_prompt = """
You are a biomedical question expansion engine for information retrieval over the Reactome biological pathway database.
Given a single user question, generate **exactly 4** alternate standalone questions. These should be:
- Semantically related to the original question.
- Lexically diverse to improve retrieval via vector search and RAG-fusion.
- Biologically enriched with inferred or associated details.
Your goal is to improve recall of relevant documents by expanding the original query using:
- Synonymous gene/protein names (e.g., EGFR, ErbB1, HER1)
- Pathway or process-level context (e.g., signal transduction, apoptosis)
- Known diseases, phenotypes, or biological functions
- Cellular localization (e.g., nucleus, cytoplasm, membrane)
- Upstream/downstream molecular interactions
Rules:
- Each question must be **fully standalone** (no "this"/"it").
- Do not change the core intent—preserve the user's informational goal.
- Use appropriate biological terminology and Reactome-relevant concepts.
- Vary the **phrasing**, **focus**, or **biological angle** of each question.
- If the input is ambiguous, infer a biologically meaningful interpretation.
Output:
Return only a valid JSON array of 4 strings (no explanations, no metadata).
"""

prompt = ChatPromptTemplate.from_messages(
[("system", system_prompt), ("user", "Original Question: {rephrased_input}")]
)

return prompt | llm | StrOutputParser() | RunnableLambda(QueryExpansionParser)
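A hypothetical end-to-end call (question text is illustrative):

```python
# Hypothetical usage sketch: expand one standalone question into 4 variants.
expander = create_query_expander(llm)
variants: list[str] = await expander.ainvoke(
    {"rephrased_input": "What role does EGFR play in signal transduction?"}
)
# variants holds 4 lexically diverse, biologically enriched standalone questions
```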