This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Callable, Literal, Optional, Sequence, Union, cast | |
from langchain.chat_models.base import BaseChatModel | |
from langchain_core.runnables import RunnableConfig | |
from langchain_core.tools import BaseTool | |
from langgraph.func import END | |
from langgraph.graph import StateGraph | |
from langgraph.graph.graph import CompiledGraph | |
from langgraph.prebuilt import ToolNode | |
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage | |
from langgraph.prebuilt.chat_agent_executor import AgentState, LanguageModelLike, Prompt, StateSchema, _get_prompt_runnable, _get_state_value, _should_bind_tools, _validate_chat_history | |
from langgraph.utils.runnable import RunnableCallable | |
from langgraph.types import Checkpointer | |
from langgraph.store.base import BaseStore | |
def create_custom_react_agent( | |
model: LanguageModelLike, | |
tools: Sequence[Union[BaseTool, Callable]], | |
customize_workflow: Callable[[StateGraph], StateGraph] = lambda x: x, | |
prompt: Optional[Prompt] = None, | |
name: Optional[str] = None, # agent name | |
checkpointer: Optional[Checkpointer] = None, | |
store: Optional[BaseStore] = None, | |
interrupt_before: Optional[list[str]] = None, | |
interrupt_after: Optional[list[str]] = None, | |
debug: bool = False, | |
) -> CompiledGraph: | |
state_schema = AgentState | |
tool_node = ToolNode(tools) | |
tool_classes = list(tool_node.tools_by_name.values()) | |
tool_calling_enabled = len(tool_classes) > 0 | |
# 도구가 있을경우, 모델에 도구를 바인드한다. | |
if _should_bind_tools(model, tool_classes) and tool_calling_enabled: | |
model = cast(BaseChatModel, model).bind_tools(tool_classes) | |
# 프롬프트와 모델을 체이닝한 러너블을 생성한다. | |
model_runnable = _get_prompt_runnable(prompt) | model | |
# return_direct가 True라면 이 도구가 불린이후로 AgentExecutor가 루핑을 멈춰야한다. | |
should_return_direct = {t.name for t in tool_classes if t.return_direct} | |
# 에이전트 그래프에서 에이전트가 더 많은 단계를 수행해야하는지 여부 | |
# True - 더 많은 단계가 필요 | |
# False - 작업이 완료되었거나 더 이상 단계가 필요하지 않음 | |
def _are_more_steps_needed(state: StateSchema, response: BaseMessage) -> bool: | |
# Agent가 도구를 호출했는지 여부, 도구를 호출했다면 응답을 받아야하므로 단계가 더 필요하다. | |
has_tool_calls = isinstance(response, AIMessage) and response.tool_calls | |
# response가 AIMessage라면 방금 호출한 Tool Call이 모두 return_direct가 True인지 확인 | |
all_tools_return_direct = ( | |
all(call["name"] in should_return_direct for call in response.tool_calls) | |
if isinstance(response, AIMessage) | |
else False | |
) | |
# remaining_steps는 그래프가 재귀제한에 도달하지 않도록 하기위해 사용된다. | |
remaining_steps = _get_state_value(state, "remaining_steps", None) | |
is_last_step = _get_state_value(state, "is_last_step", False) | |
return ( | |
# 마지막 단계에 도달했는데 응답에 도구호출이 포함되어있으므로, 도구 호출 결과처리를 위한 추가 단계가 필요. | |
(remaining_steps is None and is_last_step and has_tool_calls) | |
# 남은 단계가 없지만 return_direct가 True인 도구가 있으므로, | |
# 추가적으로 이 도구의 결과를 처리하기 위한 단계가 필요하다. | |
or ( | |
remaining_steps is not None | |
and remaining_steps < 1 | |
and all_tools_return_direct | |
) | |
# 남은 단계가 거의 없는 상황에서 도구 호출을 해버린 경우. | |
or (remaining_steps is not None and remaining_steps < 2 and has_tool_calls) | |
) | |
def call_model(state: StateSchema, config: RunnableConfig) -> StateSchema: | |
messages = _get_state_value(state, "messages") | |
_validate_chat_history(messages) | |
response = cast(AIMessage, model_runnable.invoke(state, config)) | |
# add agent name to the AIMessage | |
response.name = name | |
if _are_more_steps_needed(state, response): | |
return { | |
"messages": [ | |
AIMessage( | |
id=response.id, | |
content="Sorry, need more steps to process this request.", | |
) | |
] | |
} | |
# We return a list, because this will get added to the existing list | |
return {"messages": [response]} | |
async def acall_model(state: StateSchema, config: RunnableConfig) -> StateSchema: | |
messages = _get_state_value(state, "messages") | |
_validate_chat_history(messages) | |
response = cast(AIMessage, await model_runnable.ainvoke(state, config)) | |
# add agent name to the AIMessage | |
response.name = name | |
if _are_more_steps_needed(state, response): | |
return { | |
"messages": [ | |
AIMessage( | |
id=response.id, | |
content="Sorry, need more steps to process this request.", | |
) | |
] | |
} | |
# We return a list, because this will get added to the existing list | |
return {"messages": [response]} | |
def should_continue(state: StateSchema) -> Union[str, list]: | |
messages = _get_state_value(state, "messages") | |
last_message = messages[-1] | |
# 마지막 메세지가 도구호출을 하지않았다면 끝낼 수 있음 | |
if not isinstance(last_message, AIMessage) or not last_message.tool_calls: | |
return END | |
else: | |
return "tools" | |
should_continue_destinations = ["tools", END] | |
workflow = StateGraph(state_schema) | |
workflow.add_node("agent", RunnableCallable(call_model, acall_model)) | |
workflow.add_node("tools", tool_node) | |
workflow.set_entry_point("agent") | |
workflow.add_conditional_edges( | |
"agent", | |
should_continue, | |
path_map=should_continue_destinations | |
) | |
def route_tool_responses(state: StateSchema) -> Literal["agent", "__end__"]: | |
# 마지막 메세지부터 살펴봄 | |
for m in reversed(_get_state_value(state, "messages")): | |
if not isinstance(m, ToolMessage): | |
break | |
# 가장 마지막 도구호출이 should_return_direct인 경우 바로 종료해도됨 | |
if m.name in should_return_direct: | |
return END | |
# 아니라면 다시 agent로 돌아가야함 | |
return "agent" | |
if should_return_direct: | |
workflow.add_conditional_edges("tools", route_tool_responses) | |
else: | |
workflow.add_edge("tools", "agent") | |
workflow = customize_workflow(workflow) | |
return workflow.compile( | |
checkpointer=checkpointer, | |
store=store, | |
interrupt_before=interrupt_before, | |
interrupt_after=interrupt_after, | |
debug=debug, | |
name=name, | |
) |
Langgraph.prebuilt의 create_react_agent
를 이용해 ReAct Agent를 생성하면 이미 StateGraph가 컴파일되어서 반환되기때문에 이걸로 Agent를 만들면 워크플로우를 수정할 수 없다. Agent를 개발하다보니 추가적인 단계를 정의해야할 일이 생겨 더 이상 create_react_agent
를 사용할 수 없었다.
그래서 create_react_agent의 코드를 분석해서 커스터마이즈 가능하도록 일부 코드들을 들고와서 create_custom_react_agent
를 만들었다. StateGraph가 컴파일 되기전에 customize_workflow
로 graph를 받아서 추가로 workflow를 수정한 뒤에 컴파일 할 수 있다.
주석도 열심히 달아두었기때문에 처음부터 ReAct Agent를 구축하려는 분들에게도 도움이 될 것같다.
사용 예시

아무것도 건드리지 않았을때는 이렇게 기본형태의 ReAct Agent 그래프를 볼 수 있다.
def customize_workflow(graph: StateGraph) -> StateGraph: graph.add_node("load_memories", load_memories) graph.add_edge("load_memories", "agent") graph.set_entry_point("load_memories") graph.edges.remove((START, "agent")) return graph self.agent = create_custom_react_agent( model=llm, tools=tools, prompt=prompt, checkpointer=memory, customize_workflow=customize_workflow, debug=debug )
LTM 추가를 위해 메모리를 불러오는 작업(load_memories)를 Entrypoint로 변경하는 작업을 수행하면 다음과 같이 그래프가 변경된다.

'프로그래밍 > AI,ML' 카테고리의 다른 글
LangGraph Agent에 장기기억(LTM)추가하기 (0) | 2025.04.03 |
---|---|
반쪽짜리 Contextual Retrieval로 RAG 강화 해보기 (0) | 2025.03.26 |
사내 AI Agent 구축기 (1) | 2025.03.11 |
[ComfyUI] Workflow를 Python API로 만들기 (0) | 2025.01.21 |
[ComfyUI] AI를 이용한 배너광고 자동 생성 워크플로우 (6) | 2025.01.17 |