Enhancing Text-to-SQL Agents with Step-by-Step Reasoning
One cool outcome of the DeepSeek-R1 release is that LLMs are now starting to show their thinking (`<think>` tokens) in the response, similar to OpenAI’s o1 and o3-mini. Encouraging an LLM to think more deeply has a lot of benefits:
- No more black-box answers! You can see the reasoning behind your LLM’s responses in real-time.
- Users get insight into how the model reaches its conclusions.
- Spot and fix prompt mistakes with clarity.
- Transparency makes AI decisions feel more reliable.
- When humans and AI share reasoning, working together becomes effortless.
So here we are: I’ve built a LangGraph SQL agent with tool calling that surfaces a similar reasoning process (CoT responses). It is a ReAct agent (Reason + Act) that combines LangChain’s SQL toolkit with LangGraph’s graph-based execution. Here’s how it works:
Now, let’s understand the thinking process.
The agent starts with a system prompt that structures its thinking:
I’ve mapped out the exact steps our SQL agent takes, from the moment it receives a question until it returns the final query:
Four-Phase Thinking Process

1. Reasoning Phase (`<reasoning>` tag)
   - Explains information needs
   - Describes expected outcomes
   - Identifies challenges
   - Justifies approach

2. Analysis Phase (`<analysis>` tag)
   - Tables and joins needed
   - Required columns
   - Filters and conditions
   - Ordering/grouping logic

3. Query Phase (`<query>` tag)
   - Constructs SQL following the rules:
     - SELECT statements only
     - Proper syntax
     - Default LIMIT 10
     - Verified schema

4. Verification Phase (`<error_check>` and `<final_check>` tags)
   - Validates reasoning
   - Confirms approach
   - Checks completeness
   - Verifies output
Here’s a visualization of the process:
Here’s the full prompt template:
```python
query_gen_system = """
I am an SQL expert who helps analyze database queries. I have access to tools for interacting with the database. When given a question, I'll think through it carefully and explain my reasoning in natural language.
Then I'll walk through my analysis process:
1. First, I'll understand what tables and data I need
2. Then, I'll verify the schema and relationships
3. Finally, I'll construct an appropriate SQL query
For each query, I'll think about:
- What tables are involved and how they connect
- Any special conditions or filters needed
- How to handle potential edge cases
- The most efficient way to get the results
<reasoning>
I will **always** include this section before writing a query. Here, I will:
- Explain what information I need and why
- Describe my expected outcome
- Identify potential challenges
- Justify my query structure
If this section is missing, I will rewrite my response to include it.
</reasoning>
<analysis>
Here I break down the key components needed for the query:
- Required tables and joins
- Important columns and calculations
- Any specific filters or conditions
- Proper ordering and grouping
</analysis>
<query>
The final SQL query
</query>
<error_check>
If there's an error, I'll explain:
- What went wrong
- Why it happened
- How to fix it
</error_check>
<final_check>
Before finalizing, I will verify:
- Did I include a clear reasoning section?
- Did I explain my approach before querying?
- Did I provide an analysis of the query structure?
- If any of these are missing, I will revise my response.
</final_check>
Important rules:
1. Only use SELECT statements, no modifications
2. Verify all schema assumptions
3. Use proper SQLite syntax
4. Limit results to 10 unless specified
5. Double-check all joins and conditions
6. Always include tool_analysis and tool_reasoning for each tool call
"""
The main part of our agent’s thinking process is complete — we’ve covered the flow and the detailed prompt that guides its reasoning. Now, let’s move to the next part: Building the LangGraph SQL Agent.
First, let’s look at the graph implementation:
```python
query_gen_prompt = ChatPromptTemplate.from_messages([
("system", query_gen_system),
MessagesPlaceholder(variable_name="messages"),
])
query_gen_model = query_gen_prompt | ChatOpenAI(
model="gpt-4o-mini", temperature=0).bind_tools(tools=sql_db_toolkit_tools)
class State(TypedDict):
messages: Annotated[list, add_messages]
graph_builder = StateGraph(State)
def query_gen_node(state: State):
return {"messages": [query_gen_model.invoke(state["messages"])]}
checkpointer = MemorySaver()
graph_builder.add_node("query_gen", query_gen_node)
query_gen_tools_node = ToolNode(tools=sql_db_toolkit_tools)
graph_builder.add_node("query_gen_tools", query_gen_tools_node)
graph_builder.add_conditional_edges(
"query_gen",
tools_condition,
{"tools": "query_gen_tools", END: END},
)
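# Tool results flow back into query_gen so the model can keep
# reasoning over them (the ReAct loop).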
graph_builder.add_edge("query_gen_tools", "query_gen")
graph_builder.set_entry_point("query_gen")
graph = graph_builder.compile(checkpointer=checkpointer)
```
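With the graph compiled, you can already run a quick smoke test. Here’s a minimal sketch (it assumes the database connection and `sql_db_toolkit_tools` from the complete listing further down):

```python
# Single invocation: the thread_id keys MemorySaver's checkpoint,
# so follow-up questions with the same id share conversation history.
result = graph.invoke(
    {"messages": [("user", "How many albums are in the database?")]},
    config={"configurable": {"thread_id": 1}},
)
print(result["messages"][-1].content)
```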
Now, here’s the crucial part: how we extract and display the thinking from our agent’s responses. The `process_event` function below:
- Extracts each thinking phase from the reasoning tags we defined
- Formats the output in a readable way
- Captures the final SQL query when generated
- Shows the agent’s thought process in real-time
````python
def extract_section(text: str, section: str) -> str:
pattern = f"<{section}>(.*?)</{section}>"
match = re.search(pattern, text, re.DOTALL)
return match.group(1).strip() if match else ""
def process_event(event: Dict[str, Any]) -> Optional[str]:
if 'query_gen' in event:
messages = event['query_gen']['messages']
for message in messages:
content = message.content if hasattr(message, 'content') else ""
reasoning = extract_section(content, "reasoning")
if reasoning:
print(format_section("", reasoning))
analysis = extract_section(content, "analysis")
if analysis:
print(format_section("", analysis))
error_check = extract_section(content, "error_check")
if error_check:
print(format_section("", error_check))
final_check = extract_section(content, "final_check")
if final_check:
print(format_section("", final_check))
if hasattr(message, 'tool_calls'):
for tool_call in message.tool_calls:
tool_name = tool_call['name']
if tool_name == 'sql_db_query':
return tool_call['args']['query']
            query = extract_section(content, "query")
            if query:
                # Pull the SQL out of the fenced code block, if present
                sql_match = re.search(
                    r'```sql\n(.*?)\n```', query, re.DOTALL)
                if sql_match:
                    return sql_match.group(1).strip()
    return None
````
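For a quick sanity check, `extract_section` returns just the text between a pair of tags (this assumes `re` is imported, as in the complete listing below):

```python
sample = "<reasoning>I need the Track and InvoiceLine tables.</reasoning>"
print(extract_section(sample, "reasoning"))
# -> I need the Track and InvoiceLine tables.
```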
To use it, we simply stream events from `graph.stream`:
```python
def run_query(query_text: str):
print(f"\nAnalyzing: {query_text}")
for event in graph.stream({"messages": [("user", query_text)]},
config={"configurable": {"thread_id": 12}}):
if sql := process_event(event):
print(f"\nGenerated SQL: {sql}")
            return sql
```
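Calling it is a one-liner; the thinking sections print as they stream in, and the function returns the generated SQL:

```python
sql = run_query("What are the top 5 best-selling tracks by revenue?")
```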
Here’s the complete code to make this all work (it assumes the Chinook sample database is saved locally as chinook.db):
````python
import os
import getpass
from typing import Dict, Any
import re
from typing_extensions import TypedDict
from typing import Annotated, Optional
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver
def _set_env(key: str):
    # Prompt for the key only if it isn't already set in the environment
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}: ")

_set_env("OPENAI_API_KEY")
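# Chinook: a sample SQLite database modeling a digital media store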
db_file = "chinook.db"
engine = create_engine(f"sqlite:///{db_file}")
db = SQLDatabase(engine=engine)
toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(model="gpt-4o-mini"))
sql_db_toolkit_tools = toolkit.get_tools()
query_gen_system = """
I am an SQL expert who helps analyze database queries. I have access to tools for interacting with the database. When given a question, I'll think through it carefully and explain my reasoning in natural language.
Then I'll walk through my analysis process:
1. First, I'll understand what tables and data I need
2. Then, I'll verify the schema and relationships
3. Finally, I'll construct an appropriate SQL query
For each query, I'll think about:
- What tables are involved and how they connect
- Any special conditions or filters needed
- How to handle potential edge cases
- The most efficient way to get the results
<reasoning>
I will **always** include this section before writing a query. Here, I will:
- Explain what information I need and why
- Describe my expected outcome
- Identify potential challenges
- Justify my query structure
If this section is missing, I will rewrite my response to include it.
</reasoning>
<analysis>
Here I break down the key components needed for the query:
- Required tables and joins
- Important columns and calculations
- Any specific filters or conditions
- Proper ordering and grouping
</analysis>
<query>
The final SQL query
</query>
<error_check>
If there's an error, I'll explain:
- What went wrong
- Why it happened
- How to fix it
</error_check>
<final_check>
Before finalizing, I will verify:
- Did I include a clear reasoning section?
- Did I explain my approach before querying?
- Did I provide an analysis of the query structure?
- If any of these are missing, I will revise my response.
</final_check>
Important rules:
1. Only use SELECT statements, no modifications
2. Verify all schema assumptions
3. Use proper SQLite syntax
4. Limit results to 10 unless specified
5. Double-check all joins and conditions
6. Always include tool_analysis and tool_reasoning for each tool call
"""
query_gen_prompt = ChatPromptTemplate.from_messages([
("system", query_gen_system),
MessagesPlaceholder(variable_name="messages"),
])
query_gen_model = query_gen_prompt | ChatOpenAI(
model="gpt-4o-mini", temperature=0).bind_tools(tools=sql_db_toolkit_tools)
class State(TypedDict):
messages: Annotated[list, add_messages]
graph_builder = StateGraph(State)
def query_gen_node(state: State):
return {"messages": [query_gen_model.invoke(state["messages"])]}
checkpointer = MemorySaver()
graph_builder.add_node("query_gen", query_gen_node)
query_gen_tools_node = ToolNode(tools=sql_db_toolkit_tools)
graph_builder.add_node("query_gen_tools", query_gen_tools_node)
graph_builder.add_conditional_edges(
"query_gen",
tools_condition,
{"tools": "query_gen_tools", END: END},
)
graph_builder.add_edge("query_gen_tools", "query_gen")
graph_builder.set_entry_point("query_gen")
graph = graph_builder.compile(checkpointer=checkpointer)
def format_section(title: str, content: str) -> str:
if not content:
return ""
return f"\n{content}\n"
def extract_section(text: str, section: str) -> str:
pattern = f"<{section}>(.*?)</{section}>"
match = re.search(pattern, text, re.DOTALL)
return match.group(1).strip() if match else ""
def process_event(event: Dict[str, Any]) -> Optional[str]:
if 'query_gen' in event:
messages = event['query_gen']['messages']
for message in messages:
content = message.content if hasattr(message, 'content') else ""
reasoning = extract_section(content, "reasoning")
if reasoning:
print(format_section("", reasoning))
analysis = extract_section(content, "analysis")
if analysis:
print(format_section("", analysis))
error_check = extract_section(content, "error_check")
if error_check:
print(format_section("", error_check))
final_check = extract_section(content, "final_check")
if final_check:
print(format_section("", final_check))
if hasattr(message, 'tool_calls'):
for tool_call in message.tool_calls:
tool_name = tool_call['name']
if tool_name == 'sql_db_query':
return tool_call['args']['query']
            query = extract_section(content, "query")
            if query:
                # Pull the SQL out of the fenced code block, if present
                sql_match = re.search(
                    r'```sql\n(.*?)\n```', query, re.DOTALL)
                if sql_match:
                    return sql_match.group(1).strip()
return None
def run_query(query_text: str):
print(f"\nAnalyzing your question: {query_text}")
final_sql = None
for event in graph.stream({"messages": [("user", query_text)]},
config={"configurable": {"thread_id": 12}}):
sql = process_event(event)
if sql:
final_sql = sql
if final_sql:
print(
"\nBased on my analysis, here's the SQL query that will answer your question:")
print(f"\n{final_sql}")
return final_sql
def interactive_sql():
print("\nWelcome to the SQL Assistant! Type 'exit' to quit.")
while True:
try:
query = input("\nWhat would you like to know? ")
if query.lower() in ['exit', 'quit']:
print("\nThank you for using SQL Assistant!")
break
run_query(query)
except KeyboardInterrupt:
print("\nThank you for using SQL Assistant!")
break
except Exception as e:
print(f"\nAn error occurred: {str(e)}")
print("Please try again with a different query.")
if __name__ == "__main__":
    interactive_sql()
````
Let’s run it and take a look! Here’s the agent in action:
I’ve tested this implementation with several models (GPT-4o, GPT-4o-mini, and Claude 3.5 Haiku), and the results are promising. Here is a sample thinking output:
What are the top 5 best-selling tracks by revenue?
Analyzing your question: What are the top 5 best-selling tracks by revenue?
To determine the top 5 best-selling tracks by revenue, I need to analyze the relevant tables that contain information about tracks and their sales. Typically, this would involve a "tracks" table that includes track details and a "sales" or "orders" table that records sales transactions.
My expected outcome is a list of the top 5 tracks sorted by total revenue generated from sales. The challenge here is to ensure that I correctly join the tables and aggregate the sales data to calculate the total revenue for each track.
I will structure the query to:
1. Join the "tracks" table with the "sales" table on the track ID.
2. Sum the revenue for each track.
3. Order the results by total revenue in descending order.
4. Limit the results to the top 5 tracks.
I will first check the database schema to confirm the table names and their relationships.
- Required tables: "tracks" and "sales" (or equivalent names).
- Important columns: Track ID, track name, and revenue from sales.
- Specific filters: None needed, but I will aggregate sales data.
- Proper ordering: By total revenue in descending order, limited to 5 results.
Now, I will check the database for the existing tables to confirm their names and structure.
Now that I have confirmed the relevant tables and their structures, I can proceed to construct the SQL query. The "Track" table contains information about each track, including its ID and price. The "InvoiceLine" table records each sale, linking to the "Track" table via the TrackId, and includes the quantity sold and unit price.
To calculate the total revenue for each track, I will:
1. Join the "Track" table with the "InvoiceLine" table on the TrackId.
2. Multiply the UnitPrice by the Quantity for each sale to get the revenue for that sale.
3. Sum the revenue for each track.
4. Order the results by total revenue in descending order.
5. Limit the results to the top 5 tracks.
This approach will ensure that I accurately capture the best-selling tracks by revenue.
- Required tables: "Track" and "InvoiceLine".
- Important columns: TrackId, Name (from Track), UnitPrice, Quantity (from InvoiceLine).
- Specific filters: None needed, as I want all tracks.
- Proper ordering: By total revenue in descending order, limited to 5 results.
Now, I will construct the SQL query based on this analysis.
- I included a clear reasoning section explaining the need for the query.
- I provided an analysis of the query structure, detailing the tables and columns involved.
- I executed the query and received results without errors.
The query successfully returned the top 5 best-selling tracks by revenue. Here are the results:
1. **The Woman King** - $3.98
2. **The Fix** - $3.98
3. **Walkabout** - $3.98
4. **Hot Girl** - $3.98
5. **Gay Witch Hunt** - $3.98
All tracks generated the same revenue, which indicates that they may have been sold in equal quantities or at the same price point.
Everything is in order, and I have verified all steps.
Based on my analysis, here's the SQL query that will answer your question:
```sql
SELECT
t.TrackId,
t.Name,
SUM(il.UnitPrice * il.Quantity) AS TotalRevenue
FROM
Track t
JOIN
InvoiceLine il ON t.TrackId = il.TrackId
GROUP BY
t.TrackId, t.Name
ORDER BY
TotalRevenue DESC
LIMIT 5;
```
As you can see, the reasoning is there and shows all the thinking steps. The output demonstrates how our agent thinks, showing its work every step of the way instead of just jumping straight to the answer. Feel free to adapt this approach for your own use cases!