2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -35,7 +35,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-latest, macos-13]
+        os: [ubuntu-latest, macos-15-intel]

     steps:
       - uses: actions/checkout@v4
26 changes: 22 additions & 4 deletions src/agent/profiles/base.py
@@ -1,4 +1,4 @@
-from typing import Annotated, TypedDict
+from typing import Annotated, Literal, TypedDict

 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -7,6 +7,7 @@
 from langgraph.graph.message import add_messages

 from agent.tasks.rephrase import create_rephrase_chain
+from agent.tasks.safety_checker import create_safety_checker
 from tools.external_search.state import SearchState, WebSearchResult
 from tools.external_search.workflow import create_search_workflow

@@ -27,6 +28,8 @@ class OutputState(TypedDict, total=False):
 class BaseState(InputState, OutputState, total=False):
     rephrased_input: str  # LLM-generated query from user input
     chat_history: Annotated[list[BaseMessage], add_messages]
+    safety: str  # LLM-assessed safety level of the user input
+    reason_unsafe: str


 class BaseGraphBuilder:
@@ -38,21 +41,36 @@ def __init__(
         embedding: Embeddings,
     ) -> None:
         self.rephrase_chain: Runnable = create_rephrase_chain(llm)
+        self.safety_checker: Runnable = create_safety_checker(llm)
         self.search_workflow: Runnable = create_search_workflow(llm)

     async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
         rephrased_input: str = await self.rephrase_chain.ainvoke(
             {
                 "user_input": state["user_input"],
-                "chat_history": state["chat_history"],
+                "chat_history": state.get("chat_history", []),
             },
             config,
         )
-        return BaseState(rephrased_input=rephrased_input)
+        safety_check: BaseState = await self.safety_checker.ainvoke(
+            {"rephrased_input": rephrased_input},
+            config,
+        )
+        return BaseState(
+            rephrased_input=rephrased_input,
+            safety=safety_check["safety"].lower(),
+            reason_unsafe=safety_check["reason_unsafe"],
+        )

+    def proceed_with_research(self, state: BaseState) -> Literal["Continue", "Finish"]:
+        return "Continue" if state["safety"] == "true" else "Finish"
+
     async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
         search_results: list[WebSearchResult] = []
-        if config["configurable"]["enable_postprocess"]:
+        if (
+            config["configurable"].get("enable_postprocess", False)
+            and state["safety"] == "true"
+        ):
             result: SearchState = await self.search_workflow.ainvoke(
                 SearchState(
                     input=state["rephrased_input"],
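The router added to BaseGraphBuilder only inspects the "safety" key that preprocess now writes into the state, so it can be exercised on its own. A minimal usage sketch, not part of the diff; the builder instance and state values are illustrative:

# Hypothetical usage sketch; `builder` is an instance of a BaseGraphBuilder subclass.
safe_state = BaseState(
    rephrased_input="What does TP53 regulate?", safety="true", reason_unsafe=""
)
unsafe_state = BaseState(
    rephrased_input="How can a pathogen be made more infectious?",
    safety="false",
    reason_unsafe="dual-use request",
)

assert builder.proceed_with_research(safe_state) == "Continue"
assert builder.proceed_with_research(unsafe_state) == "Finish"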
9 changes: 0 additions & 9 deletions src/agent/profiles/cross_database.py
@@ -22,7 +22,6 @@


 class CrossDatabaseState(BaseState):
-    safety: str  # LLM-assessed safety level of the user input
     query_language: str  # language of the user input

     reactome_query: str  # LLM-generated query for Reactome
@@ -122,14 +121,6 @@ async def check_question_safety(
         else:
             return CrossDatabaseState(safety=result.binary_score)

-    async def proceed_with_research(
-        self, state: CrossDatabaseState
-    ) -> Literal["Continue", "Finish"]:
-        if state["safety"] == "Yes":
-            return "Continue"
-        else:
-            return "Finish"
-
     async def identify_query_language(
         self, state: CrossDatabaseState, config: RunnableConfig
     ) -> CrossDatabaseState:
34 changes: 29 additions & 5 deletions src/agent/profiles/react_to_me.py
@@ -7,6 +7,7 @@
 from langgraph.graph.state import StateGraph

 from agent.profiles.base import BaseGraphBuilder, BaseState
+from agent.tasks.unsafe_answer import create_unsafe_answer_generator
 from retrievers.reactome.rag import create_reactome_rag


@@ -23,6 +24,7 @@ def __init__(
         super().__init__(llm, embedding)

         # Create runnables (tasks & tools)
+        self.unsafe_answer_generator = create_unsafe_answer_generator(llm)
         self.reactome_rag: Runnable = create_reactome_rag(
             llm, embedding, streaming=True
         )
@@ -32,25 +34,47 @@ def __init__(
         # Set up nodes
         state_graph.add_node("preprocess", self.preprocess)
         state_graph.add_node("model", self.call_model)
+        state_graph.add_node("generate_unsafe_response", self.generate_unsafe_response)
         state_graph.add_node("postprocess", self.postprocess)
         # Set up edges
         state_graph.set_entry_point("preprocess")
-        state_graph.add_edge("preprocess", "model")
+        state_graph.add_conditional_edges(
+            "preprocess",
+            self.proceed_with_research,
+            {"Continue": "model", "Finish": "generate_unsafe_response"},
+        )
         state_graph.add_edge("model", "postprocess")
+        state_graph.add_edge("generate_unsafe_response", "postprocess")
         state_graph.set_finish_point("postprocess")

         self.uncompiled_graph: StateGraph = state_graph

+    async def generate_unsafe_response(
+        self, state: ReactToMeState, config: RunnableConfig
+    ) -> ReactToMeState:
+        answer: str = await self.unsafe_answer_generator.ainvoke(
+            {
+                "user_input": state["rephrased_input"],
+                "reason_unsafe": state["reason_unsafe"],
+            },
+            config,
+        )
+        return ReactToMeState(
+            chat_history=[
+                HumanMessage(state["user_input"]),
+                AIMessage(answer),
+            ],
+            answer=answer,
+        )
+
     async def call_model(
         self, state: ReactToMeState, config: RunnableConfig
     ) -> ReactToMeState:
         result: dict[str, Any] = await self.reactome_rag.ainvoke(
             {
                 "input": state["rephrased_input"],
-                "chat_history": (
-                    state["chat_history"]
-                    if state["chat_history"]
-                    else [HumanMessage(state["user_input"])]
+                "chat_history": state.get(
+                    "chat_history", [HumanMessage(state["user_input"])]
                 ),
             },
             config,
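Taken together with the base-class changes, the compiled graph now branches at preprocess instead of always calling the model. A sketch of compiling and invoking it, assuming `builder` is an instance of the graph builder defined in this module; names, config, and the question are illustrative rather than taken from the PR:

# Hypothetical sketch (inside an async function).
graph = builder.uncompiled_graph.compile()
result = await graph.ainvoke(
    {"user_input": "How could I make a virus more transmissible?"},
    config={"configurable": {"enable_postprocess": False}},
)
# The safety gate routes this question to generate_unsafe_response, so
# result["answer"] holds the refusal text and postprocess skips the web search.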
36 changes: 28 additions & 8 deletions src/agent/tasks/safety_checker.py
@@ -4,25 +4,45 @@
 from pydantic import BaseModel, Field

 safety_check_message = """
-You are an expert grader with a strong understanding of ethical guidelines and responsible AI usage.
+You are an expert scientific assistant. You have advanced training in scientific ethics, dual-use assessment, and responsible AI.

-Your task is to determine whether a user-submitted question is appropriate and not harmful.
-Provide a binary output as either:
-Yes: The question is appropriate and not harmful.
-No: The question is harmful, offensive, inappropriate, or unethical.
+Tasks:
+
+1. Safety & Appropriateness Check
+- Assess the question and determine if it is safe and appropriate for responsible AI use.
+- The question and any plausible answer to it must NOT be harmful, offensive, unethical, or violate scientific or ethical guidelines.
+- The question or its likely answers must NOT request, describe, or explore how to carry out harmful, illegal, dangerous, or dual-use activities (such as modifying pathogens, increasing infectivity, or other forms of misuse)—even if the scenario is hypothetical, fictional, or indirect.
+- The question must NOT request personal, medical, legal, or other types of advice outside the scope of scientific knowledge.
+- If a question is ambiguous or could plausibly be used for harm, treat it as unsafe.
+- Treat hypothetical, fictional, or made-up scenarios with the same level of scrutiny as real-world questions.
+
+2. Reactome Relevance Check
+- Determine if the question is relevant to biology, life sciences, molecular biology, or related topics.
+- Mark questions as not relevant if they are about unrelated topics (such as programming, math, history, trivia, etc.).
+
+IMPORTANT:
+- If the standalone question is unsafe or not relevant return "safety": "false".
+- If the standalone question is both safe **and** relevant, return "safety": "true".
+
+Return only a JSON object in the following format:
+"safety": "true" or "false", // Use string, not boolean. Mark as "false" if unsafe OR not relevant.
+"reason_unsafe": "..." // If 'safety' is false, briefly state the reason
 """

 safety_check_prompt = ChatPromptTemplate.from_messages(
     [
         ("system", safety_check_message),
-        ("human", "User question: \n\n {input}"),
+        ("human", "User question: \n\n {rephrased_input}"),
     ]
 )


 class SafetyCheck(BaseModel):
-    binary_score: str = Field(
-        description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'Yes' or 'No'."
+    safety: str = Field(
+        description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'true' or 'false'."
     )
+    reason_unsafe: str = Field(
+        description="If 'safety' is false, briefly state the reason; if 'safety' is true, leave this field empty."
+    )

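The create_safety_checker factory itself sits below the shown hunk, so its body is not part of this diff. Because base.py indexes the result like a dict (safety_check["safety"]), a plausible minimal sketch of the factory binds SafetyCheck as structured output and dumps it to a dict; the RunnableLambda step is an assumption, not code from the PR:

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.runnables import Runnable, RunnableLambda


def create_safety_checker(llm: BaseChatModel) -> Runnable:
    # Assumed sketch: enforce the SafetyCheck schema via structured output,
    # then convert the pydantic object to a plain dict for dict-style access.
    structured_llm = llm.with_structured_output(SafetyCheck)
    return (
        safety_check_prompt
        | structured_llm
        | RunnableLambda(lambda check: check.model_dump())
    ).with_config(run_name="safety_checker")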
40 changes: 40 additions & 0 deletions src/agent/tasks/unsafe_answer.py
@@ -0,0 +1,40 @@
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import Runnable
+
+unsafe_answer_message = """
+You are an expert scientific assistant operating under the React-to-Me platform. React-to-Me helps both experts and non-experts explore molecular biology using trusted data from the Reactome database.
+
+You have advanced training in scientific ethics, dual-use research concerns, and responsible AI use.
+
+You will receive two inputs:
+1. The user's question.
+2. A system-generated variable called `reason_unsafe`, which explains why the question cannot be answered.
+
+Your task is to clearly, respectfully, and firmly explain to the user *why* their question cannot be answered, based solely on the `reason_unsafe` input. Do **not** attempt to answer, rephrase, or guide the user toward answering the original question.
+
+You must:
+- Politely explain the refusal, grounded in the `reason_unsafe`.
+- Emphasize React-to-Me's mission: to support responsible exploration of molecular biology through trusted databases.
+- Suggest examples of appropriate topics (e.g., protein function, pathways, gene interactions using Reactome/UniProt).
+
+You must not provide any workaround, implicit answer, or redirection toward unsafe content.
+"""
+
+unsafe_answer_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", unsafe_answer_message),
+        (
+            "user",
+            "Question:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}",
+        ),
+    ]
+)
+
+
+def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable:
+    streaming_llm = llm.model_copy(update={"streaming": True})
+    return (unsafe_answer_prompt | streaming_llm | StrOutputParser()).with_config(
+        run_name="unsafe_answer"
+    )
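A brief usage sketch for the new task; because the copied model is created with streaming=True, the refusal can be streamed token by token. Inputs and variable names here are illustrative:

# Hypothetical usage sketch (inside an async function); `llm` is any BaseChatModel.
refusal_chain = create_unsafe_answer_generator(llm)
async for chunk in refusal_chain.astream(
    {
        "user_input": "How could I increase a pathogen's infectivity?",
        "reason_unsafe": "Dual-use request about enhancing pathogen infectivity.",
    }
):
    print(chunk, end="", flush=True)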