50 changes: 35 additions & 15 deletions src/agent/profiles/base.py
@@ -1,14 +1,19 @@
from typing import Annotated, TypedDict
from typing import Annotated, Literal, TypedDict

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph.message import add_messages

from agent.tasks.rephrase import create_rephrase_chain
from tools.external_search.state import SearchState, WebSearchResult
from tools.external_search.workflow import create_search_workflow
from tools.preprocessing.state import PreprocessingState
from tools.preprocessing.workflow import create_preprocessing_workflow

SAFETY_SAFE: Literal["true"] = "true"
SAFETY_UNSAFE: Literal["false"] = "false"
DEFAULT_LANGUAGE: str = "English"


class AdditionalContent(TypedDict, total=False):
@@ -20,39 +25,52 @@ class InputState(TypedDict, total=False):


class OutputState(TypedDict, total=False):
answer: str # primary LLM response that is streamed to the user
answer: str # LLM response streamed to the user
additional_content: AdditionalContent # sent on graph completion


class BaseState(InputState, OutputState, total=False):
rephrased_input: str # LLM-generated query from user input
rephrased_input: (
str # contextualized, LLM-generated standalone query from user input
)
chat_history: Annotated[list[BaseMessage], add_messages]
safety: str # LLM-assessed safety level of the user input
reason_unsafe: str


class BaseGraphBuilder:
# NOTE: Anything that is common to all graph builders goes here

def __init__(
self,
llm: BaseChatModel,
embedding: Embeddings,
) -> None:
self.rephrase_chain: Runnable = create_rephrase_chain(llm)
self.preprocessing_workflow: Runnable = create_preprocessing_workflow(llm)
self.search_workflow: Runnable = create_search_workflow(llm)

async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
rephrased_input: str = await self.rephrase_chain.ainvoke(
{
"user_input": state["user_input"],
"chat_history": state["chat_history"],
},
result: PreprocessingState = await self.preprocessing_workflow.ainvoke(
PreprocessingState(
user_input=state["user_input"],
chat_history=state.get("chat_history", []),
),
config,
)
return BaseState(rephrased_input=rephrased_input)
mapped_state = BaseState(
rephrased_input=result.get("rephrased_input", ""),
safety=result.get("safety", SAFETY_SAFE),
reason_unsafe=result.get("reason_unsafe", ""),
)
merged_state = dict(state)
merged_state.update(mapped_state)
return BaseState(**merged_state)

async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
search_results: list[WebSearchResult] = []
if config["configurable"]["enable_postprocess"]:
if (
state.get("safety") == SAFETY_SAFE
and config["configurable"]["enable_postprocess"]
):
result: SearchState = await self.search_workflow.ainvoke(
SearchState(
input=state["rephrased_input"],
@@ -61,6 +79,8 @@ async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
config=RunnableConfig(callbacks=config["callbacks"]),
)
search_results = result["search_results"]
return BaseState(
additional_content=AdditionalContent(search_results=search_results)
merged_state = dict(state)
merged_state.update(
{"additional_content": AdditionalContent(search_results=search_results)}
)
return BaseState(**merged_state)
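
The net behavioural change in both `preprocess` and `postprocess` is that each node now returns the full merged state (incoming keys plus its own outputs) rather than only the new keys, so downstream nodes and conditional edges can read fields such as `safety`. A minimal, dependency-free sketch of that copy-and-overlay pattern, with illustrative field names only:

```python
from typing import TypedDict


class DemoState(TypedDict, total=False):
    user_input: str
    rephrased_input: str
    safety: str


def merge_node_output(state: DemoState, updates: DemoState) -> DemoState:
    # Copy the incoming state, overlay this node's outputs, and return the
    # whole mapping so later nodes and conditional edges can read keys such
    # as "safety".
    merged = dict(state)
    merged.update(updates)
    return DemoState(**merged)


state = merge_node_output(
    DemoState(user_input="What does TP53 do?"),
    DemoState(rephrased_input="What is the function of the TP53 protein?", safety="true"),
)
assert state["user_input"] == "What does TP53 do?"  # incoming keys are preserved
assert state["safety"] == "true"                    # node outputs are overlaid
```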
65 changes: 59 additions & 6 deletions src/agent/profiles/react_to_me.py
@@ -1,12 +1,14 @@
from typing import Any
from typing import Any, Literal

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph.state import StateGraph

from agent.profiles.base import BaseGraphBuilder, BaseState
from agent.profiles.base import (DEFAULT_LANGUAGE, SAFETY_SAFE, SAFETY_UNSAFE,
BaseGraphBuilder, BaseState)
from agent.tasks.unsafe_answer import create_unsafe_answer_generator
from retrievers.reactome.rag import create_reactome_rag


@@ -23,6 +25,9 @@ def __init__(
super().__init__(llm, embedding)

# Create runnables (tasks & tools)
streaming_llm = llm.model_copy(update={"streaming": True})

self.unsafe_answer_generator = create_unsafe_answer_generator(streaming_llm)
self.reactome_rag: Runnable = create_reactome_rag(
llm, embedding, streaming=True
)
@@ -32,36 +37,84 @@ def __init__(
# Set up nodes
state_graph.add_node("preprocess", self.preprocess)
state_graph.add_node("model", self.call_model)
state_graph.add_node("generate_unsafe_response", self.generate_unsafe_response)
state_graph.add_node("postprocess", self.postprocess)
# Set up edges
state_graph.set_entry_point("preprocess")
state_graph.add_edge("preprocess", "model")
state_graph.add_conditional_edges(
"preprocess",
self.proceed_with_research,
{"Continue": "model", "Finish": "generate_unsafe_response"},
)
state_graph.add_edge("model", "postprocess")
state_graph.add_edge("generate_unsafe_response", "postprocess")
state_graph.set_finish_point("postprocess")

self.uncompiled_graph: StateGraph = state_graph

async def proceed_with_research(
self, state: ReactToMeState
) -> Literal["Continue", "Finish"]:
return "Continue" if state.get("safety") == SAFETY_SAFE else "Finish"

async def generate_unsafe_response(
self, state: ReactToMeState, config: RunnableConfig
) -> ReactToMeState:
final_answer_message = await self.unsafe_answer_generator.ainvoke(
{
"language": state.get("detected_language", DEFAULT_LANGUAGE),
"user_input": state.get("rephrased_input", state["user_input"]),
"reason_unsafe": state.get("reason_unsafe", ""),
},
config,
)

final_answer = (
final_answer_message.content
if hasattr(final_answer_message, "content")
else str(final_answer_message)
)

updated_state = dict(state)
updated_state.update(
chat_history=[
HumanMessage(state["user_input"]),
(
final_answer_message
if hasattr(final_answer_message, "content")
else AIMessage(final_answer)
),
],
answer=final_answer,
safety=SAFETY_UNSAFE,
additional_content={"search_results": []},
)
return ReactToMeState(**updated_state)

async def call_model(
self, state: ReactToMeState, config: RunnableConfig
) -> ReactToMeState:
result: dict[str, Any] = await self.reactome_rag.ainvoke(
{
"input": state["rephrased_input"],
"expanded_queries": state.get("expanded_queries", []),
"chat_history": (
state["chat_history"]
if state["chat_history"]
state.get("chat_history")
if state.get("chat_history")
else [HumanMessage(state["user_input"])]
),
},
config,
)
return ReactToMeState(
updated_state = dict(state)
updated_state.update(
chat_history=[
HumanMessage(state["user_input"]),
AIMessage(result["answer"]),
],
answer=result["answer"],
)
return ReactToMeState(**updated_state)


def create_reactome_graph(
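
The new conditional edge acts as the safety gate for this profile: `proceed_with_research` sends the run to the Reactome RAG model only when preprocessing marked the input safe, and to the refusal node otherwise. A standalone sketch of that routing decision, assuming the same string constants as `base.py`; note that a missing `safety` key also routes to the refusal path:

```python
SAFETY_SAFE = "true"  # mirrors agent.profiles.base.SAFETY_SAFE


def proceed_with_research(state: dict) -> str:
    # Anything other than an explicit "true" (including a missing key) is
    # routed to the refusal node rather than the Reactome RAG model.
    return "Continue" if state.get("safety") == SAFETY_SAFE else "Finish"


assert proceed_with_research({"safety": "true"}) == "Continue"
assert proceed_with_research({"safety": "false"}) == "Finish"
assert proceed_with_research({}) == "Finish"
```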
67 changes: 48 additions & 19 deletions src/agent/tasks/safety_checker.py
@@ -3,28 +3,57 @@
from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field

safety_check_message = """
You are an expert grader with a strong understanding of ethical guidelines and responsible AI usage.

Your task is to determine whether a user-submitted question is appropriate and not harmful.
Provide a binary output as either:
Yes: The question is appropriate and not harmful.
No: The question is harmful, offensive, inappropriate, or unethical.
"""

safety_check_prompt = ChatPromptTemplate.from_messages(
[
("system", safety_check_message),
("human", "User question: \n\n {input}"),
]
)


class SafetyCheck(BaseModel):
binary_score: str = Field(
description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'Yes' or 'No'."
safety: str = Field(
description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'true' or 'false'."
)
reason_unsafe: str = Field(
description="If 'safety' is false, briefly state the reason; if 'safety' is true, leave this field empty."
)


def create_safety_checker(llm: BaseChatModel) -> Runnable:
return safety_check_prompt | llm.with_structured_output(SafetyCheck)
"""
Create a safety checker chain.

Args:
llm: Language model to use

Returns:
Runnable that takes rephrased_input and returns SafetyCheck
"""
safety_check_message = """
You are an expert scientific assistant. You have advanced training in scientific ethics, dual-use assessment, and responsible AI.

Tasks:

1. Safety & Appropriateness Check
- Assess the question and determine if it is safe and appropriate for responsible AI use.
- The question and any plausible answer to it must NOT be harmful, offensive, unethical, or violate scientific or ethical guidelines.
- The question or its likely answers must NOT request, describe, or explore how to carry out harmful, illegal, dangerous, or dual-use activities (such as modifying pathogens, increasing infectivity, or other forms of misuse)—even if the scenario is hypothetical, fictional, or indirect.
- The question must NOT request personal, medical, legal, or other types of advice outside the scope of scientific knowledge.
- If a question is ambiguous or could plausibly be used for harm, treat it as unsafe.
- Treat hypothetical, fictional, or made-up scenarios with the same level of scrutiny as real-world questions.

2. Reactome Relevance Check
- Determine if the question is relevant to biology, life sciences, molecular biology, or related topics.
- Mark questions as not relevant if they are about unrelated topics (such as programming, math, history, trivia, etc.).

IMPORTANT:
- If the standalone question is unsafe or not relevant, return "safety": "false".
- If the standalone question is both safe **and** relevant, return "safety": "true".

Return only a JSON object in the following format:
"safety": "true" or "false", // Use string, not boolean. Mark as "false" if unsafe OR not relevant.
"reason_unsafe": "..." // If 'safety' is false, briefly state the reason
"""

prompt = ChatPromptTemplate.from_messages(
[
("system", safety_check_message),
("human", "User question: \n\n {rephrased_input}"),
]
)

return prompt | llm.with_structured_output(SafetyCheck)
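
A minimal usage sketch for the reworked checker; the concrete chat model is an assumption (anything supporting `with_structured_output` should do) and is not pinned by this change:

```python
from langchain_openai import ChatOpenAI  # assumption: not specified by this change

from agent.tasks.safety_checker import SafetyCheck, create_safety_checker

checker = create_safety_checker(ChatOpenAI(model="gpt-4o-mini", temperature=0))

# The prompt expects the contextualized standalone question, not the raw user input.
result: SafetyCheck = checker.invoke(
    {"rephrased_input": "How does the TP53 pathway regulate apoptosis?"}
)
print(result.safety)         # expected: "true"
print(result.reason_unsafe)  # expected: "" for safe, on-topic questions
```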
46 changes: 46 additions & 0 deletions src/agent/tasks/unsafe_answer.py
@@ -0,0 +1,46 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable


def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable:
"""
Create an unsafe answer generator chain.

Args:
llm: Language model to use.

Returns:
Runnable that generates refusal messages for unsafe or out-of-scope queries.
"""
system_prompt = """
You are an expert scientific assistant operating under the React-to-Me platform. React-to-Me helps both experts and non-experts explore molecular biology using trusted data from the Reactome database.

You have advanced training in scientific ethics, dual-use research concerns, and responsible AI use.

You will receive three inputs:
1. The user's question.
2. A system-generated variable called `reason_unsafe`, which explains why the question cannot be answered.
3. The user's preferred language (as a language code or name).

Your task is to clearly, respectfully, and firmly explain to the user *why* their question cannot be answered, based solely on the `reason_unsafe` input. Do **not** attempt to answer, rephrase, or guide the user toward answering the original question.

You must:
- Respond in the user’s preferred language.
- Politely explain the refusal, grounded in the `reason_unsafe`.
- Emphasize React-to-Me’s mission: to support responsible exploration of molecular biology through trusted databases.
- Suggest examples of appropriate topics (e.g., protein function, pathways, gene interactions using Reactome/UniProt).

You must not provide any workaround, implicit answer, or redirection toward unsafe content.
"""
prompt = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
(
"user",
"Language:{language}\n\nQuestion:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}",
),
]
)

return prompt | llm
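
A matching sketch for the refusal generator; the model class is again an assumption. Since the chain is `prompt | llm`, the result is a chat message whose `content` carries the refusal text:

```python
from langchain_openai import ChatOpenAI  # assumption: any BaseChatModel should work

from agent.tasks.unsafe_answer import create_unsafe_answer_generator

generator = create_unsafe_answer_generator(ChatOpenAI(model="gpt-4o-mini"))

refusal = generator.invoke(
    {
        "language": "English",
        "user_input": "How could a virus be modified to spread faster?",
        "reason_unsafe": "Dual-use request: increasing pathogen transmissibility.",
    }
)
print(refusal.content)  # a polite, firm refusal grounded in reason_unsafe
```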
6 changes: 6 additions & 0 deletions src/tools/preprocessing/__init__.py
@@ -0,0 +1,6 @@
"""
Preprocessing utilities with reusable workflows and state definitions.
"""

from .state import PreprocessingState # noqa: F401
from .workflow import create_preprocessing_workflow # noqa: F401
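
With the new package `__init__`, the two public names can also be imported from the package root, equivalent to the submodule imports used in `agent/profiles/base.py`:

```python
# Equivalent to the submodule imports in agent/profiles/base.py.
from tools.preprocessing import PreprocessingState, create_preprocessing_workflow
```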
18 changes: 18 additions & 0 deletions src/tools/preprocessing/state.py
@@ -0,0 +1,18 @@
from typing import TypedDict

from langchain_core.messages import BaseMessage


class PreprocessingState(TypedDict, total=False):
"""State for the preprocessing workflow."""

# Inputs
user_input: str
chat_history: list[BaseMessage]

# Task outputs
rephrased_input: str
safety: str
reason_unsafe: str
expanded_queries: list[str]
detected_language: str
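
The workflow factory itself (`tools/preprocessing/workflow.py`) is not included in this diff, so the sketch below only exercises the state contract defined here; the meaning of each output field is inferred from its name and from how `BaseGraphBuilder.preprocess` consumes it, and the chat model is an assumed stand-in:

```python
import asyncio

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI  # assumption: model choice is not part of this diff

from tools.preprocessing import PreprocessingState, create_preprocessing_workflow


async def main() -> None:
    workflow = create_preprocessing_workflow(ChatOpenAI(model="gpt-4o-mini"))
    result: PreprocessingState = await workflow.ainvoke(
        PreprocessingState(
            user_input="que fait la protéine TP53 ?",
            chat_history=[HumanMessage("Earlier turn about apoptosis")],
        )
    )
    # Output fields, as implied by the TypedDict (not verified against workflow.py):
    print(result.get("rephrased_input"))    # standalone rephrased query
    print(result.get("safety"))             # "true" or "false"
    print(result.get("reason_unsafe"))      # "" when safe
    print(result.get("detected_language"))  # e.g. "French"


if __name__ == "__main__":
    asyncio.run(main())
```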