import asyncio
import uuid
from textwrap import dedent
from typing import List
import art
import weave
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from litellm import acompletion
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt
from art.langgraph import init_chat_model, wrap_rollout
from art.skypilot import SkyPilotBackend
from art.utils import iterate_dataset
# Initialize the trainable model. Plain art.Model is inference-only; training
# requires art.TrainableModel, which checkpoints weights on top of `base_model`.
# The name and project strings below are placeholders.
# The SkyPilot backend is created inside main(), since its cluster
# initialization is asynchronous.
model = art.TrainableModel(
    name="email-search-agent",
    project="email-agent-demo",
    base_model="Qwen/Qwen2.5-7B-Instruct",
)
# Data models
class EmailResult(BaseModel):
    message_id: str
    subject: str
    from_address: str
    date: str
    snippet: str

class FinalAnswer(BaseModel):
    answer: str
    source_ids: List[str]

class Scenario(BaseModel):
    id: str
    question: str
    answer: str
    inbox_address: str
    query_date: str

class EmailScenario(BaseModel):
    step: int
    scenario: Scenario
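
# ProjectTrajectory extends art.Trajectory so each rollout can carry its
# final answer alongside the standard reward, messages, and metrics.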
class ProjectTrajectory(art.Trajectory):
    final_answer: FinalAnswer | None = None
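
# Structured output schema for the LLM judge; `accept` becomes the binary
# correctness metric (and reward) in the rollout below.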
class CorrectnessJudgeResponse(BaseModel):
    reasoning: str = Field(description="Explanation of the reasoning process.")
    accept: bool = Field(description="Whether the AI answer should be accepted.")
# Mock email functions (replace with real implementation)
def search_emails(inbox: str, keywords: List[str], sent_before: str) -> List[EmailResult]:
    """Mock email search function - replace with real implementation"""
    return [
        EmailResult(
            message_id="msg_123",
            subject=f"Subject matching {keywords[0]}",
            from_address="sender@example.com",
            date="2024-01-15",
            snippet=f"Email snippet containing {keywords[0]}",
        )
    ]

def read_email(message_id: str) -> EmailResult | None:
    """Mock email read function - replace with real implementation"""
    return EmailResult(
        message_id=message_id,
        subject="Full email subject",
        from_address="sender@example.com",
        date="2024-01-15",
        snippet="Full email content here...",
    )
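
# A real implementation would query an actual mailbox. A minimal sketch,
# assuming the inbox has been indexed into a hypothetical SQLite table
# emails(message_id, subject, from_address, date, snippet, body):
#
#     def search_emails(inbox, keywords, sent_before):
#         like_clauses = " AND ".join("body LIKE ?" for _ in keywords)
#         rows = conn.execute(
#             "SELECT message_id, subject, from_address, date, snippet "
#             f"FROM emails WHERE date < ? AND {like_clauses}",
#             (sent_before, *[f"%{kw}%" for kw in keywords]),
#         )
#         return [EmailResult(**dict(row)) for row in rows]
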
# Correctness evaluation
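# LLM-as-judge: compares the agent's answer to the reference answer.
# litellm accepts a Pydantic model as response_format for OpenAI models, and
# tenacity retries transient API or validation failures up to three times.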
@retry(stop=stop_after_attempt(3))
async def judge_correctness(scenario: Scenario, answer: str) -> CorrectnessJudgeResponse:
    system_prompt = dedent("""
        You are given a question, the reference answer, and an answer generated by an AI assistant.
        Your task is to decide whether the AI answer is correct and should be accepted.
    """)
    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": (
                f"Question: {scenario.question}\n"
                f"Reference answer: {scenario.answer}\n"
                f"AI answer: {answer}"
            ),
        },
    ]
    response = await acompletion(
        model="openai/gpt-4o-mini",
        messages=messages,
        response_format=CorrectnessJudgeResponse,
    )
    return CorrectnessJudgeResponse.model_validate_json(
        response.choices[0].message.content or "{}"
    )
# Main rollout function
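# @weave.op traces each rollout; wrap_rollout (applied in main) lets ART
# capture the chat-model calls made inside the LangGraph agent as training data.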
@weave.op
async def rollout(model: art.Model, email_scenario: EmailScenario) -> ProjectTrajectory:
    scenario = email_scenario.scenario
    MAX_TURNS = 10
    traj = ProjectTrajectory(
        reward=0.0,
        messages_and_choices=[],
        metadata={
            "scenario_id": scenario.id,
            "step": email_scenario.step,
        },
    )
    system_prompt = dedent(f"""
        You are an email search agent. Use the tools to search emails and find answers.
        User's email address: {scenario.inbox_address}
        Today's date: {scenario.query_date}
        When you find the answer, use return_final_answer_tool with the answer and source message IDs.
    """)
    final_answer = None
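
    # Tools are defined inside the rollout so their closures capture the
    # per-scenario inbox address and the `final_answer` slot.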
    @tool
    def search_inbox_tool(keywords: List[str]) -> List[dict]:
        """Search inbox for emails matching keywords"""
        results = search_emails(scenario.inbox_address, keywords, scenario.query_date)
        # EmailResult is a Pydantic model, not a dataclass, so use model_dump()
        # (dataclasses.asdict() would raise a TypeError here).
        return [result.model_dump() for result in results]
    @tool
    def read_email_tool(message_id: str) -> dict | None:
        """Read a specific email by message ID"""
        email = read_email(message_id)
        return email.model_dump() if email else None

    @tool
    def return_final_answer_tool(answer: str, reference_message_ids: List[str]) -> dict:
        """Return final answer with source message IDs"""
        nonlocal final_answer
        final_answer = FinalAnswer(answer=answer, source_ids=reference_message_ids)
        return final_answer.model_dump()

    tools = [search_inbox_tool, read_email_tool, return_final_answer_tool]
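
    # init_chat_model (from art.langgraph) returns a LangChain chat model wired
    # to the ART backend, so completions generated during the rollout are
    # recorded for training; temperature=1.0 keeps sampling diverse across
    # rollouts in the same group.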
    chat_model = init_chat_model(model.name, temperature=1.0)
    react_agent = create_react_agent(chat_model, tools)
    try:
        config = {
            "configurable": {"thread_id": str(uuid.uuid4())},
            "recursion_limit": MAX_TURNS,
        }
        await react_agent.ainvoke({
            "messages": [
                SystemMessage(content=system_prompt),
                HumanMessage(content=scenario.question),
            ]
        }, config=config)
        if final_answer:
            traj.final_answer = final_answer
            correctness_judge_response = await judge_correctness(scenario, final_answer.answer)
            traj.metrics["correct"] = float(correctness_judge_response.accept)
            # Without a reward the trajectories all score 0.0 and there is
            # nothing to optimize, so use the judge's verdict as a binary reward.
            traj.reward = float(correctness_judge_response.accept)
    except Exception as e:
        print(f"Error running agent: {e}")
        traj.messages_and_choices.append({"role": "assistant", "content": f"Error: {str(e)}"})

    return traj
# Main training function
async def main():
    # Sample training scenarios (replace with real data)
    training_scenarios = [
        Scenario(
            id="1",
            question="Find emails about the quarterly budget",
            answer="Budget meeting scheduled for Q4 review",
            inbox_address="user@company.com",
            query_date="2024-01-20",
        ),
        Scenario(
            id="2",
            question="Look for urgent project updates",
            answer="Project deadline moved to next month",
            inbox_address="user@company.com",
            query_date="2024-01-20",
        ),
    ]
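
    # Optional: calling weave.init("email-agent") here enables trace logging
    # for the @weave.op-decorated rollout; without it the decorator simply
    # calls the underlying function.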
    # Provision the training cluster and register the model with it.
    # SkyPilotBackend is initialized asynchronously; the cluster name and GPU
    # type below are placeholders.
    backend = await SkyPilotBackend.initialize_cluster(
        cluster_name="art-email-agent",
        gpu="H100",
    )
    await model.register(backend)
    # Training configuration
    training_config = {
        "groups_per_step": 2,
        "num_epochs": 3,
        "rollouts_per_group": 4,
        "learning_rate": 1e-5,
        "max_steps": 5,
    }
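
    # groups_per_step = scenarios sampled per training step;
    # rollouts_per_group = attempts per scenario, which are compared against
    # one another to compute relative advantages.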
    # Training iterator
    training_iterator = iterate_dataset(
        training_scenarios,
        groups_per_step=training_config["groups_per_step"],
        num_epochs=training_config["num_epochs"],
        initial_step=await model.get_step(),
    )
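
    # Starting from model.get_step() lets a restarted run resume from the last
    # saved checkpoint instead of repeating earlier batches.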
    # Training loop
    for batch in training_iterator:
        print(f"Training step {batch.step}, epoch {batch.epoch}")

        # Create trajectory groups
        groups = []
        for scenario in batch.items:
            groups.append(
                art.TrajectoryGroup([
                    wrap_rollout(model, rollout)(
                        model, EmailScenario(step=batch.step, scenario=scenario)
                    )
                    for _ in range(training_config["rollouts_per_group"])
                ])
            )
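
        # Each group holds several rollouts of the same scenario; ART scores
        # them against one another when computing the policy update.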
        # Gather trajectories
        finished_groups = await art.gather_trajectory_groups(
            groups,
            pbar_desc="gather",
            max_exceptions=training_config["rollouts_per_group"] * len(batch.items),
        )

        # Train model
        await model.train(
            finished_groups,
            config=art.TrainConfig(learning_rate=training_config["learning_rate"]),
        )

        print(f"Completed training step {batch.step}")
        if batch.step >= training_config["max_steps"]:
            break
if __name__ == "__main__":
    asyncio.run(main())