Let's work through a Text2SQL use case where we are starting from scratch without a nice and clean dataset of questions, SQL queries, or expected responses.
Here are the steps:
1. Implement Text2SQL
2. Set up the dataset and evaluators
3. Run experiments
Implement Text2SQL
We are going to use the NBA dataset, which contains information about every game played from 2014 to 2018. We will use DuckDB as our database.
import duckdb
from datasets import load_dataset
data = load_dataset("suzyanil/nba-data")["train"]
conn = duckdb.connect(database=":memory:", read_only=False)
conn.register("nba", data.to_pandas())
conn.query("SELECT * FROM nba LIMIT 5").to_df().to_dict(orient="records")[0]
Here's an example of one row of a game in the dataset.
Let's start by implementing some simple text2SQL logic.
import os
import openai

client = openai.AsyncClient()

columns = conn.query("DESCRIBE nba").to_df().to_dict(orient="records")

# We will use GPT-4o to start
TASK_MODEL = "gpt-4o"
CONFIG = {"model": TASK_MODEL}

system_prompt = (
    "You are a SQL expert, and you are given a single table named nba with the following columns:\n"
    f"{','.join(column['column_name'] + ': ' + column['column_type'] for column in columns)}\n"
    "Write a SQL query corresponding to the user's request. Return just the query text, "
    "with no formatting (backticks, markdown, etc.)."
)
async def generate_query(input):
    response = await client.chat.completions.create(
        model=TASK_MODEL,
        temperature=0,
        messages=[
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": input,
            },
        ],
    )
    return response.choices[0].message.content
query = await generate_query("Who won the most games?")
print(query)
To set up an experiment, we need a dataset, a task, and evaluators. Let's set up each.
Set up the dataset
questions = [
    "Which team won the most games?",
    "Which team won the most games in 2015?",
    "Who led the league in 3 point shots?",
    "Which team had the biggest difference in records across two consecutive years?",
    "What is the average number of free throws per year?",
]
Let's store the data above as a versioned dataset in Arize.
import pandas as pd

# Arize datasets client (import paths may vary slightly across SDK versions)
from arize.experimental.datasets import ArizeDatasetsClient
from arize.experimental.datasets.utils.constants import GENERATIVE

arize_client = ArizeDatasetsClient(
    developer_key=os.environ.get("ARIZE_DEVELOPER_KEY"),
    api_key=os.environ.get("ARIZE_API_KEY"),
)

# Create a dataset from a DataFrame; add your own data here
test_df = pd.DataFrame([{"question": question} for question in questions])

dataset_id = arize_client.create_dataset(
    space_id=space_id,  # your Arize space ID
    dataset_name=dataset_name,  # a name for this dataset
    dataset_type=GENERATIVE,
    data=test_df,
)
dataset = arize_client.get_dataset(space_id=space_id, dataset_id=dataset_id)
dataset.head()
Set up the task
Next, we'll define the task. The task is to generate SQL queries from natural language questions.
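The full task implementation isn't shown here, so below is a minimal sketch. It assumes the experiment runner passes each dataset row as a dict containing the question column we created above, that async tasks are supported, and that the task returns a JSON string with query, results, and error keys, which is the shape the evaluators below expect. The function is named task to match the run_experiment calls later; adjust the signature to match your version of the Arize experiments SDK.
import json

# Minimal sketch of the task (assumed signature): generate SQL for the question,
# execute it against the in-memory DuckDB table, and return the query, results,
# and any error as a JSON string.
async def task(dataset_row):
    question = dataset_row["question"]
    query = await generate_query(question)
    results, error = None, None
    try:
        results = conn.query(query).to_df().to_dict(orient="records")
    except Exception as e:
        error = str(e)
    # default=str keeps numpy/pandas values JSON-serializable
    return json.dumps({"query": query, "results": results, "error": error}, default=str)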
Finally, we'll define the evaluators. We'll use the following simple scoring functions to check whether the generated SQL queries are correct.
import json

# Test if there are no SQL execution errors
def no_error(output):
    output = json.loads(output)
    return 1.0 if output.get("error") is None else 0.0

# Test if the query has results
def has_results(output):
    output = json.loads(output)
    results = output.get("results")
    has_results = results is not None and len(results) > 0
    return 1.0 if has_results else 0.0
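The walkthrough below refers to an initial evaluation. With the task and evaluators above in place, that first experiment run would look something like this (the experiment name here is just illustrative):
# Run the baseline experiment with the initial prompt (name is illustrative)
experiment = arize_client.run_experiment(
    space_id=space_id,
    dataset_id=dataset_id,
    task=task,
    evaluators=[no_error, has_results],
    experiment_name="text2sql_test_initial",
)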
Now that we've run the initial evaluation, it looks like three of the results are valid, one produces SQL errors, and one has no results.
The second query, for `Which team won the most games in 2015`, filters on Date LIKE '2015%', which is not correct. The fourth query does not have TEAM in the GROUP BY clause.
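To see why the year filter fails, we can take a quick look at how dates are actually stored in the table:
# Peek at the Date column; the values end in a two-digit year (the working query
# later filters on '%/15'), so a '2015%' prefix match never finds anything.
conn.query("SELECT Date FROM nba LIMIT 3").to_df()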
Let's try to improve the prompt by showing the model an example value for each column and see if we can get better results.
samples = conn.query("SELECT * FROM nba LIMIT 1").to_df().to_dict(orient="records")[0]

sample_rows = "\n".join(
    f"{column['column_name']} | {column['column_type']} | {samples[column['column_name']]}"
    for column in columns
)

system_prompt = (
    "You are a SQL expert, and you are given a single table named nba with the following columns:\n\n"
    "Column | Type | Example\n"
    "-------|------|--------\n"
    f"{sample_rows}\n"
    "\n"
    "Write a DuckDB SQL query corresponding to the user's request. "
    "Return just the query text, with no formatting (backticks, markdown, etc.)."
)
async def generate_query(input):
    response = await client.chat.completions.create(
        model=TASK_MODEL,
        temperature=0,
        messages=[
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": input,
            },
        ],
    )
    return response.choices[0].message.content
print(await generate_query("Which team won the most games in 2015?"))
SELECT Team, COUNT(*) AS Wins FROM nba WHERE WINorLOSS = 'W' AND Date LIKE '%/15' GROUP BY Team ORDER BY Wins DESC LIMIT 1;
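Before wiring this into another experiment, we can sanity-check the generated query by executing it directly:
# Regenerate the query for the 2015 question and run it against DuckDB as a quick check
query = await generate_query("Which team won the most games in 2015?")
conn.query(query).to_df()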
Looking better! Finally, let's add an LLM-based evaluator that checks whether the generated query is valid SQL. Then we can run another experiment and compare the results.
from phoenix.evals.models import OpenAIModel
from phoenix.evals.classify import llm_classify
from arize.experimental.datasets.experiments.types import EvaluationResult
IS_SQL_EVAL_TEMPLATE = """You are a SQL expert. Is the following a valid SQL query that executes without errors? Return the single word "valid" if it is valid, and "invalid" if it is not.
[BEGIN SQL QUERY]
{query}
[END SQL QUERY]
"""
def check_is_sql(output):
    output = json.loads(output)
    query = output.get("query")
    df_in = pd.DataFrame({"query": query}, index=[0]) if query else None
    if df_in is None:
        # No query was generated, so there is nothing to classify
        return EvaluationResult(score=0.0, label="invalid", explanation="No query was generated.")
    eval_df = llm_classify(
        dataframe=df_in,
        template=IS_SQL_EVAL_TEMPLATE,
        model=OpenAIModel(model="gpt-4o"),
        rails=["valid", "invalid"],
        provide_explanation=True,
    )
    # Return score, label, and explanation; the score mirrors the label
    label = eval_df["label"][0]
    return EvaluationResult(
        score=1.0 if label == "valid" else 0.0,
        label=label,
        explanation=eval_df["explanation"][0],
    )
experiment = arize_client.run_experiment(
    space_id=space_id,
    dataset_id=dataset_id,
    task=task,
    evaluators=[no_error, has_results, check_is_sql],
    experiment_name="text2sql_test_new_prompt_and_eval-6",
)
You can see that the new prompt improves some cases, but there are still errors to iron out. As you experiment with different models, prompts, and techniques, you can continue optimizing your application until it reaches the performance thresholds you want.