Grok

GrokAgent inherits all inference methods from ChatGPTAgent and overrides the batch methods so they use xai_sdk. xai_sdk exposes xAI’s native counter-based batch state: a batch is complete once num_pending == 0.

Agent

Utilities

Examples

Inference

import asyncio

from core_genai.agents.grok.agent import GrokAgent
from core_genai.interfaces import IAgent
from core_mixins.decorators.async_ import SyncWrapper
from environs import Env

MODEL = "grok-3-mini"
PROMPT = [{"role": "user", "content": "Explain how AI works in a few words"}]

# Read environment variables from a local .env file; XAI_API_KEY is required below.
env = Env()
env.read_env(".env")

# Instantiate the agent through the IAgent factory, selecting it by the name "GrokAgent".
agent: GrokAgent = IAgent.create_agent("GrokAgent", api_key=env("XAI_API_KEY"))

async def run_async() -> None:
    """Send PROMPT to MODEL through the native async API and print text and metadata.

    Every line is tagged "[async]" so its output can be told apart from
    the blocking ``run_sync`` example that runs afterwards.
    """
    output = await agent.analyze(model=MODEL, prompt=PROMPT)
    print("[async]", agent.get_text(output))
    print("[async]", agent.get_metadata(output))

def run_sync() -> None:
    """Send PROMPT to MODEL through the blocking ``SyncWrapper`` and print the outcome.

    The text is printed first, then the metadata, both tagged "[sync]".
    """
    with SyncWrapper(agent) as blocking_agent:
        result = blocking_agent.analyze(model=MODEL, prompt=PROMPT)
        # Extractors run in order: text first, then metadata.
        for extract in (agent.get_text, agent.get_metadata):
            print("[sync]", extract(result))

if __name__ == "__main__":
    # Demonstrate both call styles: the coroutine API first, then the blocking wrapper.
    async_demo = run_async()
    asyncio.run(async_demo)
    run_sync()

Batch inference

Pass a list of BatchRequest objects, each with a prompt and an optional custom_id. The agent uploads the requests as a JSONL file through the OpenAI-compatible Files API and then creates the batch with xai_sdk. Poll the job until its state is "ended" (num_pending == 0), then retrieve the results.

import asyncio

from core_genai.agents.chatgpt.agent import BatchRequest
from core_genai.agents.grok.agent import GrokAgent
from core_genai.interfaces import IAgent
from environs import Env

MODEL = "grok-3-beta"
POLL_INTERVAL = 30  # seconds to wait between consecutive status checks

# One batch request is built per prompt; build_requests encodes the list index in custom_id.
PROMPTS = [
    "Explain what machine learning is in one sentence.",
    "Explain what a neural network is in one sentence.",
    "Explain what reinforcement learning is in one sentence.",
]

# Read environment variables from a local .env file; XAI_API_KEY is required below.
env = Env()
env.read_env(".env")

agent: GrokAgent = IAgent.create_agent("GrokAgent", api_key=env("XAI_API_KEY"))

def build_requests() -> list[BatchRequest]:
    """Wrap each entry of PROMPTS in a single-message user ``BatchRequest``.

    The custom_id of the request at position ``i`` is ``"request-<i>"``, so
    results can later be mapped back to their prompt.
    """
    batch = []
    for index, text in enumerate(PROMPTS):
        message = {"role": "user", "content": text}
        batch.append(BatchRequest(custom_id=f"request-{index}", prompt=[message]))
    return batch

async def poll_until_done(batch_id: str) -> str:
    """Check ``batch_id`` repeatedly until its state is in ``GrokAgent.TERMINAL_STATES``.

    Every observed state is printed. Consecutive checks are spaced
    POLL_INTERVAL seconds apart. Polling has no timeout. Returns the
    terminal state that ended the loop.
    """
    async def fetch_state() -> str:
        current = await agent.check_job_status(batch_id)
        print(f"  status: {current}")
        return current

    state = await fetch_state()
    while state not in GrokAgent.TERMINAL_STATES:
        await asyncio.sleep(POLL_INTERVAL)
        state = await fetch_state()
    return state

def prompt_index(request_id: object, count: int) -> "int | None":
    """Return the prompt index encoded in a ``request-<i>`` id, or None if there is none.

    Only the text after the last "-" is inspected. Indices outside
    ``[0, count)`` also yield None.
    """
    _, _, suffix = str(request_id).rpartition("-")
    if not suffix.isdecimal():
        return None
    index = int(suffix)
    return index if index < count else None


def print_results(result: dict) -> None:
    """Print each batch response next to the prompt that produced it.

    ``result["results"]`` is used as a sized iterable of per-request results
    exposing ``batch_request_id``, ``is_success``, ``response`` and
    ``error_message``. This shape is assumed from usage here; confirm it
    against ``GrokAgent.extract_job_results``. A falsy value is treated as
    "no responses".

    Ids created by ``build_requests`` are mapped back to their PROMPTS entry.
    Any other id, for example from a batch loaded via ``extract_results``,
    is printed verbatim instead of raising.
    """
    responses = result["results"] or []
    print(f"{len(responses)} response(s):\n")
    for resp in responses:
        text = agent.get_text(resp.response) if resp.is_success else f"[error: {resp.error_message}]"
        i = prompt_index(resp.batch_request_id, len(PROMPTS))
        if i is None:
            # The id does not map onto PROMPTS (foreign batch or custom id); show it raw.
            print(f"[{resp.batch_request_id}]")
        else:
            print(f"[{i}] {PROMPTS[i]!r}")
        print(f"    {text}\n")

async def main() -> None:
    """Schedule a batch for every entry in PROMPTS, wait for a terminal state, then print the results."""
    batch_requests = build_requests()
    print(f"Scheduling batch job with {len(batch_requests)} request(s)...")

    job = await agent.schedule_job(requests=batch_requests, model=MODEL)
    job_id = job["job_id"]
    print(f"Job scheduled: {job_id}  (name: {job.get('batch_name')})")

    print("Polling for completion...")
    final_state = await poll_until_done(job_id)
    print(f"Job finished with state: {final_state}")

    results = await agent.extract_job_results(job_id)
    print_results(results)

async def extract_results(batch_id: str) -> None:
    """Download the results of an existing batch and print them with ``print_results``."""
    payload = await agent.extract_job_results(batch_id)
    print_results(payload)

async def check_batches(*batch_ids: str) -> None:
    """Print ``<id>: <state>`` for each given batch id, one at a time and in the order given."""
    for job_id in batch_ids:
        status = await agent.check_job_status(job_id)
        print(f"{job_id}: {status}")

if __name__ == "__main__":
    import sys

    # Run exactly one entry point, chosen on the command line, so the script
    # never sends placeholder batch ids to the API:
    #   (no args) | run          schedule a new batch, wait for it, print results
    #   status ID [ID ...]       print the current state of existing batches
    #   results ID               print the results of an existing batch
    command, *ids = sys.argv[1:] or ["run"]
    if command == "run" and not ids:
        asyncio.run(main())
    elif command == "status" and ids:
        asyncio.run(check_batches(*ids))
    elif command == "results" and len(ids) == 1:
        asyncio.run(extract_results(ids[0]))
    else:
        sys.exit(f"usage: {sys.argv[0]} [run | status BATCH_ID... | results BATCH_ID]")