프로그래밍/AI,ML

LangGraph ReAct Agent 커스터마이즈하기

Lou Park 2025. 4. 1. 17:22
from typing import Callable, Literal, Optional, Sequence, Union, cast
from langchain.chat_models.base import BaseChatModel
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.func import END
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import ToolNode
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
from langgraph.prebuilt.chat_agent_executor import AgentState, LanguageModelLike, Prompt, StateSchema, _get_prompt_runnable, _get_state_value, _should_bind_tools, _validate_chat_history
from langgraph.utils.runnable import RunnableCallable
from langgraph.types import Checkpointer
from langgraph.store.base import BaseStore
def create_custom_react_agent(
model: LanguageModelLike,
tools: Sequence[Union[BaseTool, Callable]],
customize_workflow: Callable[[StateGraph], StateGraph] = lambda x: x,
prompt: Optional[Prompt] = None,
name: Optional[str] = None, # agent name
checkpointer: Optional[Checkpointer] = None,
store: Optional[BaseStore] = None,
interrupt_before: Optional[list[str]] = None,
interrupt_after: Optional[list[str]] = None,
debug: bool = False,
) -> CompiledGraph:
state_schema = AgentState
tool_node = ToolNode(tools)
tool_classes = list(tool_node.tools_by_name.values())
tool_calling_enabled = len(tool_classes) > 0
# 도구가 있을경우, 모델에 도구를 바인드한다.
if _should_bind_tools(model, tool_classes) and tool_calling_enabled:
model = cast(BaseChatModel, model).bind_tools(tool_classes)
# 프롬프트와 모델을 체이닝한 러너블을 생성한다.
model_runnable = _get_prompt_runnable(prompt) | model
# return_direct가 True라면 이 도구가 불린이후로 AgentExecutor가 루핑을 멈춰야한다.
should_return_direct = {t.name for t in tool_classes if t.return_direct}
# 에이전트 그래프에서 에이전트가 더 많은 단계를 수행해야하는지 여부
# True - 더 많은 단계가 필요
# False - 작업이 완료되었거나 더 이상 단계가 필요하지 않음
def _are_more_steps_needed(state: StateSchema, response: BaseMessage) -> bool:
# Agent가 도구를 호출했는지 여부, 도구를 호출했다면 응답을 받아야하므로 단계가 더 필요하다.
has_tool_calls = isinstance(response, AIMessage) and response.tool_calls
# response가 AIMessage라면 방금 호출한 Tool Call이 모두 return_direct가 True인지 확인
all_tools_return_direct = (
all(call["name"] in should_return_direct for call in response.tool_calls)
if isinstance(response, AIMessage)
else False
)
# remaining_steps는 그래프가 재귀제한에 도달하지 않도록 하기위해 사용된다.
remaining_steps = _get_state_value(state, "remaining_steps", None)
is_last_step = _get_state_value(state, "is_last_step", False)
return (
# 마지막 단계에 도달했는데 응답에 도구호출이 포함되어있으므로, 도구 호출 결과처리를 위한 추가 단계가 필요.
(remaining_steps is None and is_last_step and has_tool_calls)
# 남은 단계가 없지만 return_direct가 True인 도구가 있으므로,
# 추가적으로 이 도구의 결과를 처리하기 위한 단계가 필요하다.
or (
remaining_steps is not None
and remaining_steps < 1
and all_tools_return_direct
)
# 남은 단계가 거의 없는 상황에서 도구 호출을 해버린 경우.
or (remaining_steps is not None and remaining_steps < 2 and has_tool_calls)
)
def call_model(state: StateSchema, config: RunnableConfig) -> StateSchema:
messages = _get_state_value(state, "messages")
_validate_chat_history(messages)
response = cast(AIMessage, model_runnable.invoke(state, config))
# add agent name to the AIMessage
response.name = name
if _are_more_steps_needed(state, response):
return {
"messages": [
AIMessage(
id=response.id,
content="Sorry, need more steps to process this request.",
)
]
}
# We return a list, because this will get added to the existing list
return {"messages": [response]}
async def acall_model(state: StateSchema, config: RunnableConfig) -> StateSchema:
messages = _get_state_value(state, "messages")
_validate_chat_history(messages)
response = cast(AIMessage, await model_runnable.ainvoke(state, config))
# add agent name to the AIMessage
response.name = name
if _are_more_steps_needed(state, response):
return {
"messages": [
AIMessage(
id=response.id,
content="Sorry, need more steps to process this request.",
)
]
}
# We return a list, because this will get added to the existing list
return {"messages": [response]}
def should_continue(state: StateSchema) -> Union[str, list]:
messages = _get_state_value(state, "messages")
last_message = messages[-1]
# 마지막 메세지가 도구호출을 하지않았다면 끝낼 수 있음
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
return END
else:
return "tools"
should_continue_destinations = ["tools", END]
workflow = StateGraph(state_schema)
workflow.add_node("agent", RunnableCallable(call_model, acall_model))
workflow.add_node("tools", tool_node)
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
"agent",
should_continue,
path_map=should_continue_destinations
)
def route_tool_responses(state: StateSchema) -> Literal["agent", "__end__"]:
# 마지막 메세지부터 살펴봄
for m in reversed(_get_state_value(state, "messages")):
if not isinstance(m, ToolMessage):
break
# 가장 마지막 도구호출이 should_return_direct인 경우 바로 종료해도됨
if m.name in should_return_direct:
return END
# 아니라면 다시 agent로 돌아가야함
return "agent"
if should_return_direct:
workflow.add_conditional_edges("tools", route_tool_responses)
else:
workflow.add_edge("tools", "agent")
workflow = customize_workflow(workflow)
return workflow.compile(
checkpointer=checkpointer,
store=store,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
debug=debug,
name=name,
)
view raw react.py hosted with ❤ by GitHub

 

Langgraph.prebuilt의 create_react_agent를 이용해 ReAct Agent를 생성하면 이미 StateGraph가 컴파일되어서 반환되기때문에 이걸로 Agent를 만들면 워크플로우를 수정할 수 없다. Agent를 개발하다보니 추가적인 단계를 정의해야할 일이 생겨 더 이상 create_react_agent를 사용할 수 없었다.

 

그래서 create_react_agent의 코드를 분석해서 커스터마이즈 가능하도록 일부 코드들을 들고와서 create_custom_react_agent를 만들었다. StateGraph가 컴파일 되기전에 customize_workflow로 graph를 받아서 추가로 workflow를 수정한 뒤에 컴파일 할 수 있다.

 

주석도 열심히 달아두었기때문에 처음부터 ReAct Agent를 구축하려는 분들에게도 도움이 될 것같다.

 

 

사용 예시


아무것도 건드리지 않았을때는 이렇게 기본형태의 ReAct Agent 그래프를 볼 수 있다.

def customize_workflow(graph: StateGraph) -> StateGraph:
graph.add_node("load_memories", load_memories)
graph.add_edge("load_memories", "agent")
graph.set_entry_point("load_memories")
graph.edges.remove((START, "agent"))
return graph
self.agent = create_custom_react_agent(
model=llm,
tools=tools,
prompt=prompt,
checkpointer=memory,
customize_workflow=customize_workflow,
debug=debug
)

 

 

LTM 추가를 위해 메모리를 불러오는 작업(load_memories)를 Entrypoint로 변경하는 작업을 수행하면 다음과 같이 그래프가 변경된다.