import asyncio
from datetime import datetime
import art
from art.serverless.backend import ServerlessBackend
from dotenv import load_dotenv
from envs.echo_env import EchoAction, EchoEnv
import weave
# System prompt that seeds every trajectory built by rollout().
PROMPT: str = "Use at most 100 tokens; maximize the total character length of the output."
# Exclusive upper bound of the training loop's step range in main().
NUM_STEPS: int = 50
# Number of environment clients in the pool; each client runs one rollout
# per training step, so this is also the trajectory group size.
ROLLOUTS_PER_GROUP: int = 4
# The rollout function defines how your agent interacts with the environment
async def rollout(model: art.TrainableModel, env_client: EchoEnv) -> art.Trajectory:
    """Run one single-turn episode and return it as a finished trajectory.

    The environment is reset, the model is asked for a single completion in
    response to the system prompt, and the stripped completion text is sent
    to the environment as an ``EchoAction``. The reward reported for that
    step becomes the trajectory's reward.

    Args:
        model: Model whose OpenAI-compatible client generates the action.
        env_client: Environment client that is reset and stepped once. Its
            ``reset`` and ``step`` calls are dispatched with
            ``asyncio.to_thread`` (presumably because they block).

    Returns:
        The value of ``Trajectory.finish()`` for a trajectory holding the
        system prompt, the model's choice, and the step reward (``0.0`` when
        the environment reports no reward).
    """
    # Reset the environment to get initial state
    await asyncio.to_thread(env_client.reset)

    # Create a trajectory to store interactions and rewards
    traj = art.Trajectory(
        messages_and_choices=[{"role": "system", "content": PROMPT}],
        reward=0.0,
    )

    # Use the model to generate an action
    completion = await model.openai_client().chat.completions.create(
        model=model.inference_model_name,
        messages=traj.messages(),
        max_completion_tokens=100,
        timeout=30,
    )
    choice = completion.choices[0]
    # The message content can be None; fall back to an empty string.
    reply = (choice.message.content or "").strip()

    # Send the action to the environment and get observation/reward
    result = await asyncio.to_thread(env_client.step, EchoAction(message=reply))

    # Record the model's output and reward
    traj.messages_and_choices.append(choice)
    # A step result may carry no reward (None); record 0.0 in that case so
    # the trajectory always holds a numeric reward for training.
    step_reward = result.reward
    traj.reward = 0.0 if step_reward is None else step_reward

    return traj.finish()
async def main() -> None:
    """Register a trainable model and train it on echo-environment rollouts.

    Loads environment variables, initialises Weave, registers a freshly named
    ``art.TrainableModel`` with a ``ServerlessBackend``, starts
    ``ROLLOUTS_PER_GROUP`` environment clients, and then, for every step from
    the model's current step up to ``NUM_STEPS``, gathers one trajectory
    group (one rollout per client) and trains on it.

    Every environment client that was started is closed before this
    coroutine returns, whether training finishes or raises.
    """
    load_dotenv()
    weave.init("openenv-demo")

    # Set up the training backend
    backend = ServerlessBackend()

    # Define the model to train
    model = art.TrainableModel(
        name=f"openenv-echo-{datetime.now().strftime('%Y-%m-%d-%H%M%S')}",
        project="openenv-demo",
        base_model="OpenPipe/Qwen3-14B-Instruct",
    )
    await model.register(backend)

    # Create a pool of environment clients for efficient training.
    # The pool is filled one client at a time inside the try block so that a
    # failure while starting a later client still cleans up the earlier ones.
    env_pool: list[EchoEnv] = []
    try:
        for _ in range(ROLLOUTS_PER_GROUP):
            env_pool.append(
                EchoEnv.from_docker_image("quixote13/echo-env:latest")
            )

        # Training loop
        for step in range(await model.get_step(), NUM_STEPS):
            print(f"Gathering groups for step {step}")

            # Run multiple rollouts in parallel
            groups = await art.gather_trajectory_groups([
                art.TrajectoryGroup(
                    rollout(model, env_client)
                    for env_client in env_pool
                )
            ])

            # Train the model on collected trajectories
            await model.train(groups)
    finally:
        # Release every client that was started; one failing close must not
        # prevent the remaining clients from being closed.
        for env_client in env_pool:
            try:
                env_client.close()
            except Exception as exc:
                print(f"Failed to close environment client: {exc!r}")
# Start the training run only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())