
LLM-on-Spark: Four Patterns That Actually Scale

"Just call the LLM in a loop." 9.6 years later, you finish. Here are four patterns that actually scale to a billion rows: Spark UDFs, Ray + vLLM, warehouse-native SQL, and the Batch API. Code and costs included.

Nithin K Anil · Scaling_LLMs_to_380_million_rows

TL;DR. You have a text column and a billion rows. There are exactly four sane ways to point an LLM at that data: Spark UDFs around a self-hosted model, Ray Data + vLLM for high-throughput GPU batching, warehouse-native SQL functions (ai_query, AI_COMPLETE, AI.GENERATE_TEXT), or external APIs from Spark — preferably the Batch API. Skip to Approach 3 if your data is already in Databricks/Snowflake/BigQuery. Skip to Approach 4c for any nightly job. The other two are for the cases those don't cover.


The 2 a.m. problem

Picture a fintech with 380 million historical chat-support messages sitting in S3. On Monday, compliance walks in: classify each message for PII risk by Friday, with JSON output of the form {has_pii, pii_types[], confidence} and an audit trail.

A sensible engineer opens a notebook and writes:

for row in df.collect():
    result = openai.chat.completions.create(...)

At 800 ms per call, that loop finishes in 9.6 years. So you parallelise. The first parallel attempt hits the OpenAI rate limit in 90 seconds, the second hits OOM on the GPU, the third melts the budget. By Wednesday you've learned that "call an LLM in a loop" and "call an LLM 380 M times" are different engineering problems.

This post is about how to actually do the second one. Four patterns, with code, diagrams, rough cost-and-throughput numbers, and the trade-offs that matter at 2 a.m. on Thursday.

Why the naive loop dies

A single GPT-class request takes 200 ms to several seconds. At 1 second per row, 100 M rows is 3.2 years on one machine. Parallelism brings its own problems:

  • Hosted APIs throttle you within seconds of starting (RPM/TPM limits).
  • Self-hosted models load slowly and need careful GPU memory budgeting.
  • A 99.9% success rate on 100 M rows is still 100,000 failures.
  • Spark will retry tasks. If you don't make calls idempotent, you'll pay OpenAI twice for the same prompt.
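The arithmetic behind those numbers is worth keeping to hand. A minimal sketch that assumes every call takes the same time and nothing throttles you, which is precisely the assumption rate limits break; the 1,000-way concurrency in the last line is an illustrative figure, not a recommendation:

```python
SECONDS_PER_YEAR = 365 * 24 * 3600


def wall_clock_years(rows: int, seconds_per_call: float, concurrency: int = 1) -> float:
    """Serial time for `rows` calls, divided by how many calls are in flight at once.

    Assumes constant latency and no throttling, so treat the result as a best case.
    """
    if concurrency < 1:
        raise ValueError("concurrency must be at least 1")
    return rows * seconds_per_call / concurrency / SECONDS_PER_YEAR


print(f"{wall_clock_years(380_000_000, 0.8):.1f} years")   # the notebook loop → 9.6 years
print(f"{wall_clock_years(100_000_000, 1.0):.1f} years")   # 1 s per row → 3.2 years
print(f"{wall_clock_years(380_000_000, 0.8, 1_000) * 365:.1f} days")  # 1,000 in flight → 3.5 days
```

Even the optimistic last line needs a provider willing to serve 1,000 concurrent requests, which is where the rest of this post comes in.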

Spark, or any modern distributed engine, already solves the orchestration half: partitioning, retries, idempotent writes, lineage. The four patterns below all use such an engine as the orchestrator (Spark, Ray Data, or the warehouse's own executor). They differ mainly in where the model lives.

The four patterns at a glance

# | Pattern | Where the model runs | Best for | Rough cost (per 1 M rows, ~200-token prompts)¹
1 | Spark UDF + self-hosted | Inside each Spark executor | Privacy-sensitive, ≤13B model, 1–10 M rows | ~$30–$80 (GPU rental)
2 | Ray Data + vLLM | Dedicated GPU pool, vLLM engine | Open-weights, ≥10 M rows | ~$10–$30 (GPU rental)
3 | Warehouse-native SQL | Inside Databricks / Snowflake / BigQuery | Data already in warehouse | ~$50–$200 (per-token, vendor markup)
4a | Sync API from UDF | Hosted (OpenAI/Anthropic/Vertex) | Need answers in minutes | ~$150 (gpt-4o-mini list price)
4c | Batch API from Spark | Hosted, async | Any nightly job | ~$75 (Batch API, 50% off)

¹ Order-of-magnitude only. Real numbers depend on prompt length, model choice, region, GPU type, and your negotiated rates. Use these as a starting point for a spreadsheet, not a quote.
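For that spreadsheet, the two pricing models reduce to two formulas. A minimal sketch: prices are parameters because they change often, and the values used below are placeholders rather than any provider's current rates; the GPU line reuses the 45 min × $1.21/hr × 4 inputs from the Approach 1 table. Both functions cover only the raw token or GPU-hour bill, so system prompts, retries and idle nodes come on top:

```python
def token_cost_usd(rows: int, in_tokens: int, out_tokens: int,
                   usd_per_m_in: float, usd_per_m_out: float,
                   discount: float = 0.0) -> float:
    """Per-token pricing (Approaches 3 and 4): rows x tokens x price per million tokens.

    `discount` is a fraction, e.g. 0.5 for a Batch API that halves the price.
    """
    per_row = in_tokens * usd_per_m_in + out_tokens * usd_per_m_out
    return rows * per_row / 1_000_000 * (1 - discount)


def gpu_cost_usd(hours: float, usd_per_node_hour: float, nodes: int) -> float:
    """GPU-rental pricing (Approaches 1 and 2): wall-clock hours x hourly rate x nodes."""
    return hours * usd_per_node_hour * nodes


# Placeholder prices, not quotes: swap in your provider's current list prices.
sync_bill = token_cost_usd(1_000_000, 200, 60, usd_per_m_in=0.50, usd_per_m_out=2.00)
batch_bill = token_cost_usd(1_000_000, 200, 60, usd_per_m_in=0.50, usd_per_m_out=2.00, discount=0.5)
gpu_bill = gpu_cost_usd(hours=0.75, usd_per_node_hour=1.21, nodes=4)
print(f"sync ${sync_bill:.2f}, batch ${batch_bill:.2f}, GPU ${gpu_bill:.2f}")
# → sync $220.00, batch $110.00, GPU $3.63
```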


Approach 1 — Spark UDF around a self-hosted model

Architecture

┌──────────────────────────────────────────────────────────┐
│ Spark driver                                             │
│   ─ partitions DataFrame                                 │
│   ─ schedules tasks                                      │
└──────────────────────────────────────────────────────────┘
        │                │                │
        ▼                ▼                ▼
   ┌──────────┐    ┌──────────┐    ┌──────────┐
   │Executor 1│    │Executor 2│    │Executor 3│
   │ ┌──────┐ │    │ ┌──────┐ │    │ ┌──────┐ │
   │ │Model │ │    │ │Model │ │    │ │Model │ │   ← one copy
   │ │(GPU) │ │    │ │(GPU) │ │    │ │(GPU) │ │     per worker
   │ └──────┘ │    │ └──────┘ │    │ └──────┘ │
   │ rows ──► │    │ rows ──► │    │ rows ──► │
   └──────────┘    └──────────┘    └──────────┘

The model loads once per Python worker, and each worker chews through its slice of the data. This is the bread-and-butter approach.

The wrong way (don't do this)

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

@udf(StringType())
def summarize(text):
    # BUG 1: model loads on every row (~20 seconds, ~2 GB)
    # BUG 2: one row at a time, GPU sits idle 95% of the time
    from transformers import pipeline
    pipe = pipeline("summarization", model="facebook/bart-large-cnn")
    return pipe(text)[0]["summary_text"]

df.withColumn("summary", summarize("body")).write.parquet("...")

This works in tutorials and dies in production. On 1 M rows you'd reload the model 1 M times: at ~20 seconds a load, that is roughly 230 days of pure model loading, divided only by however many executors you run.

The right way: pandas_udf with a module-level singleton

import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType

# Module-level cache. Each Python worker process has its own copy.
_pipe = None

def get_pipe():
    """Lazy-load the model once per worker, then reuse it."""
    global _pipe
    if _pipe is None:
        from transformers import pipeline
        import torch
        _pipe = pipeline(
            "summarization",
            model="facebook/bart-large-cnn",
            device=0 if torch.cuda.is_available() else -1,  # GPU if there
            batch_size=16,                                   # internal HF batching
        )
    return _pipe

@pandas_udf(StringType())
def summarize_batch(texts: pd.Series) -> pd.Series:
    """Receives a Series of strings, runs them through the model in one batch."""
    pipe = get_pipe()
    out = pipe(texts.tolist(), truncation=True, max_length=128)
    return pd.Series([o["summary_text"] for o in out])

# Each Arrow batch handed to the UDF is this large. Tune for your GPU.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "512")

(df.withColumn("summary", summarize_batch("body"))
   .write.mode("overwrite").parquet("/lake/silver/reviews_summarised"))

The two cardinal rules: load once per worker, and make sure each call sees a real batch so the GPU isn't idle between rows. Both are documented in the Databricks pandas-UDF and batch-inference guides.
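If you would rather not manage a global by hand, functools.cache (Python 3.9+) gives the same once-per-process behaviour. A sketch with a hypothetical stand-in loader that only counts its calls, so you can see the expensive body runs once; in real code the body would be the pipeline(...) call above:

```python
import functools

load_count = 0  # how many times the "expensive" body actually ran


@functools.cache
def get_model():
    """Stand-in for an expensive load such as transformers.pipeline(...).

    functools.cache memoises this zero-argument call, so within one Python
    worker process the body runs once and every later call returns the same object.
    """
    global load_count
    load_count += 1
    return object()  # hypothetical placeholder for the real model


first, second = get_model(), get_model()
print(first is second, load_count)  # → True 1
```

Each Python worker is a separate process, so you still get one copy per worker, exactly as with the explicit global.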

When you need finer control: mapInPandas

pandas_udf enforces "one Series in, one Series out, same length." For LLMs you sometimes want to flush a sub-batch on token budget rather than row count, or skip rows mid-batch. That's what mapInPandas is for:

from typing import Iterator
import pandas as pd

def summarize_partition(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    """Receives an iterator of DataFrames, yields any DataFrames you like back."""
    pipe = get_pipe()  # same singleton from above
    for batch in batches:
        # You could split `batch` by token count here, drop bad rows, etc.
        outputs = pipe(batch["body"].tolist(), truncation=True, max_length=128)
        batch["summary"] = [o["summary_text"] for o in outputs]
        yield batch[["id", "summary"]]

result = df.mapInPandas(summarize_partition, schema="id long, summary string")
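The comment in summarize_partition leaves the token-budget split open. One way to do it, as a sketch: the estimate of roughly four characters per token is a crude heuristic assumed here, not a property of any particular tokenizer, so use the model's own tokenizer when the budget is tight. The function returns row positions rather than copies, so the caller can slice the original DataFrame:

```python
from typing import Iterator, List


def token_budget_batches(texts: List[str], max_tokens: int,
                         chars_per_token: int = 4) -> Iterator[List[int]]:
    """Yield lists of row positions whose estimated token total stays within max_tokens.

    Tokens are estimated as len(text) // chars_per_token (minimum 1). A single row
    bigger than the budget is yielded on its own rather than dropped.
    """
    current: List[int] = []
    used = 0
    for i, text in enumerate(texts):
        est = max(1, len(text) // chars_per_token)
        if current and used + est > max_tokens:
            yield current                 # flush before this row would overflow the budget
            current, used = [], 0
        current.append(i)
        used += est
    if current:
        yield current


print(list(token_budget_batches(["a" * 40, "b" * 40, "c" * 400, "d" * 4], max_tokens=25)))
# → [[0, 1], [2], [3]]
```

Inside summarize_partition you would loop over these position lists and run the pipeline on batch.iloc[positions] for each one.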

Spark 3.4+: predict_batch_udf

Spark 3.4 added predict_batch_udf, which formalises the "load once per worker, batch automatically" pattern:

from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import ArrayType, FloatType
import numpy as np

def make_predict_fn():
    """Called once per Python worker process; the returned function is reused for every batch."""
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer("BAAI/bge-small-en-v1.5", device="cuda")
    def predict(inputs: np.ndarray) -> np.ndarray:
        return model.encode(inputs.tolist(), batch_size=64,
                            normalize_embeddings=True)
    return predict

embed = predict_batch_udf(
    make_predict_fn,
    return_type=ArrayType(FloatType()),
    batch_size=64,
)

df.withColumn("embedding", embed("body")).write.parquet("...")

Numbers and trade-offs

A worked example on a small open model:

Workload | Setup | Throughput | Cost (rough)
1 M reviews, BART-large summarisation, 4× g5.2xlarge (A10) on AWS | pandas_udf + batch=16 | ~2,000 rows/sec | ~$3.60 (45 min × $1.21/hr × 4)
Same on CPU executors (m5.4xlarge × 4) | pandas_udf + batch=8 | ~80 rows/sec | ~$11 (3.5 hr × $0.77/hr × 4)

What you get:

  • Full control: fine-tunes, quantisation, on-prem GPUs, no data leaves your VPC.
  • One unified Spark job; retries, partitioning and lineage work the way your platform team expects.
  • Predictable cost — you pay for GPU hours, not per token.

What hurts:

  • You're now operating an inference cluster.
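  • Each Python worker loads its own copy of the model. For large models, make sure only one task (and so one model copy) runs per executor: set spark.task.cpus equal to spark.executor.cores, or give each task a whole GPU with spark.task.resource.gpu.amount=1. Otherwise you OOM.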
  • Spark's task model gives the GPU a batch, waits, gives it the next batch — between tasks the GPU sits idle. Fine at small scale, wasteful at large scale. That's the gap Approach 2 fills.

Reach for this when the model fits comfortably on one GPU (≤13 B at FP16, larger if quantised), the data is sensitive, and the volume is millions, not billions.


Approach 2 — Ray Data + vLLM for high-throughput batching

Architecture

┌────────────────────────────┐
│ Spark / Ray driver         │   reads input, hands batches to Ray Data
└────────────────────────────┘
            │
            ▼
┌────────────────────────────────────────────────────────┐
│ Ray Data pipeline (map_batches)                        │
│                                                        │
│   ┌──────────┐    ┌──────────┐    ┌──────────┐         │
│   │ vLLM     │    │ vLLM     │    │ vLLM     │         │
│   │ engine 1 │◄──►│ engine 2 │◄──►│ engine 3 │   ← continuous
│   │ (GPU)    │    │ (GPU)    │    │ (GPU)    │     batching,
│   └──────────┘    └──────────┘    └──────────┘     paged KV cache
│                                                        │
└────────────────────────────────────────────────────────┘
            │
            ▼
       output Parquet

Difference from Approach 1: vLLM keeps the GPU saturated by streaming requests in and out asynchronously, packing partially-finished sequences alongside fresh ones. Typical 5–20× throughput uplift on the same hardware.
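A toy model of one source of that uplift. With static batching a batch holds the GPU until its longest sequence finishes; with continuous batching a finished sequence's slot is refilled on the next decode step. The sketch counts decode steps only and ignores prefill, memory pressure and per-step cost, so it illustrates the mechanism rather than predicting vLLM's actual speed-up; the output lengths are hypothetical:

```python
import heapq
from typing import List


def static_batch_steps(out_lens: List[int], slots: int) -> int:
    """Fixed batches: each group of `slots` requests runs until its longest one finishes."""
    return sum(max(out_lens[i:i + slots]) for i in range(0, len(out_lens), slots))


def continuous_batch_steps(out_lens: List[int], slots: int) -> int:
    """Continuous batching: a finished sequence's slot is refilled on the next step.

    Each slot is tracked by the step at which it frees up; a min-heap hands each
    new request (in arrival order) to the earliest-free slot.
    """
    free_at = [0] * min(slots, len(out_lens))
    heapq.heapify(free_at)
    for n in out_lens:
        start = heapq.heappop(free_at)
        heapq.heappush(free_at, start + n)
    return max(free_at) if free_at else 0


# Hypothetical output lengths: one long answer in every four.
lens = [100, 10, 10, 10] * 4
print(static_batch_steps(lens, 4), continuous_batch_steps(lens, 4))  # → 400 160
```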

Standalone Ray + vLLM

import ray
from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor

# One vLLM engine per "concurrency" slot. Each engine owns a GPU.
config = vLLMEngineProcessorConfig(
    model_source="meta-llama/Meta-Llama-3.1-8B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,    # better mixing of prefill + decode
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
    },
    concurrency=4,    # 4 vLLM replicas, one per GPU
    batch_size=64,    # Ray Data batch size, NOT the GPU batch size
)

processor = build_llm_processor(
    config,
    # Turn each input row into a chat-style request.
    preprocess=lambda row: dict(
        messages=[{"role": "user",
                   "content": f"Summarise: {row['body']}"}],
        sampling_params=dict(temperature=0.0, max_tokens=128),
    ),
    # Turn the engine's output back into the columns you want.
    postprocess=lambda row: dict(
        id=row["id"],
        summary=row["generated_text"],
    ),
)

ds = ray.data.read_parquet("s3://lake/bronze/reviews/")
ds = processor(ds)
ds.write_parquet("s3://lake/silver/reviews_summarised/")

Ray on Spark — when you already live in Databricks

from ray.util.spark import setup_ray_cluster
import ray

setup_ray_cluster(
    max_worker_nodes=4,
    num_gpus_worker_node=1,
    num_cpus_worker_node=8,
)
ray.init()

# Hand the Spark DataFrame straight to Ray Data
ds = ray.data.from_spark(spark_df)
ds = processor(ds)              # same processor as above
ds.write_parquet("s3://lake/silver/reviews_summarised/")

You're running two cluster managers on one cluster (Ray's worker nodes live inside Spark tasks), which isn't beautiful, but it's the standard way the Databricks ecosystem reaches vLLM-class throughput today.

Numbers and trade-offs

Workload | Setup | Throughput | Cost (rough)
10 M prompts, Llama-3.1-8B, 4× A10 GPUs (g5.2xlarge) | Ray Data + vLLM | ~12,000 prompts/sec | ~$30
Same workload, same hardware, pandas_udf + HF Transformers | Approach 1 | ~1,500 prompts/sec | ~$240

What you get:

  • Highest tokens-per-dollar of any self-hosted approach.
  • Scales 1 → N GPUs with the same code; vLLM handles tensor parallel across GPUs.
  • Async-friendly: failed prompts can be retried within the engine without restarting the Spark stage.

What hurts:

  • More moving parts: Ray + vLLM + Spark, three runtimes to debug.
  • vLLM eats GPU memory aggressively (KV cache pre-allocation). Cohabiting with other workloads is finicky.
  • Resilience is still maturing — a single vLLM engine dying at high concurrency can take down the whole job.
  • Overkill for a few million rows; you'll spend more time on ops than on inference.

Reach for this when you have ≥10 M prompts per run on open-weights models and inefficient GPU utilisation costs real money.


Approach 3 — Warehouse-native LLM SQL

Architecture

┌────────────────────────────────────────────────────────┐
│ Data warehouse (Databricks / Snowflake / BigQuery)     │
│                                                        │
│   ┌─────────────┐                  ┌────────────────┐  │
│   │ table.body  │ ──ai_query()──►  │  vendor LLM    │  │
│   │             │                  │  pool (managed │  │
│   │ table.label │ ◄──────────────  │  by warehouse) │  │
│   └─────────────┘                  └────────────────┘  │
│                                                        │
└────────────────────────────────────────────────────────┘
   No data egress. No GPU ops. No Python.

Newest pattern, and for many teams the simplest: don't move the data to the model. Call an LLM as a SQL function and let the warehouse handle parallelism, retries, and throughput.

Databricks: ai_query

SELECT
  id,
  body,
  ai_query(
    'databricks-meta-llama-3-3-70b-instruct',
    'Summarise this review in one sentence: ' || body,
    modelParameters => named_struct('max_tokens', 60, 'temperature', 0.0),
    failOnError      => false      -- critical for big batches
  ) AS summary
FROM gold.reviews
WHERE country = 'US';

Without failOnError => false, one bad row aborts your whole 100 M-row query. With it, the call returns a struct instead of a string: successful rows carry the answer in its result field, failed rows carry an errorMessage, and the rest of the job completes.

Snowflake: AI_COMPLETE

SELECT
  id,
  body,
  SNOWFLAKE.CORTEX.AI_COMPLETE(
    model  => 'snowflake-llama-3.3-70b',
    prompt => 'Classify the sentiment of this review as POSITIVE, '
              || 'NEUTRAL or NEGATIVE: ' || body
  ) AS sentiment
FROM raw.reviews;

Cortex AISQL went GA in November 2025. Models are billed per-token off your existing Snowflake credits, and the function runs inside the warehouse with zero data egress.

BigQuery: AI.GENERATE_TEXT

-- One-time: register the Vertex/Gemini endpoint as a model
CREATE OR REPLACE MODEL `analytics.gemini_flash`
REMOTE WITH CONNECTION `us.vertex_ai`
OPTIONS (ENDPOINT = 'gemini-2.5-flash');

-- Then call it from SQL
SELECT *
FROM AI.GENERATE_TEXT(
  MODEL `analytics.gemini_flash`,
  (SELECT id, CONCAT('Extract product names as JSON: ', body) AS prompt
     FROM raw.reviews),
  STRUCT(0.0 AS temperature, 256 AS max_output_tokens)
);

Numbers and trade-offs

Workload | Setup | Throughput | Cost (rough)
1 M short prompts, Llama-3.3-70B via Databricks ai_query | one SQL query | ~3,000 rows/sec | ~$70–$120
1 M short prompts, Gemini 2.5 Flash via BigQuery | one SQL query | ~1,500 rows/sec | ~$50–$80

What you get:

  • Fastest path from "I have a text column" to "I have an enriched column." Hours, not weeks.
  • Zero data egress. Governance, RBAC, audit, lineage are exactly what your warehouse already enforces.
  • The platform handles parallelism, retries, throughput allocation. You stop being an inference operator.

What hurts:

  • Model menu is whatever the vendor offers.
  • Per-token pricing is usually a premium over calling the underlying API directly.
  • Less flexibility on prompt orchestration: chained prompts, tool use, multi-turn agents are clunky-to-impossible inside SQL.
  • Cross-cloud is rough. BigQuery models stay in BigQuery's region, etc.

Reach for this when the data already lives in a modern warehouse, the use case is "enrich a column," and the team would rather spend on credits than on operating a GPU cluster.


Approach 4 — External LLM APIs from Spark

If you want a frontier model (GPT-class, Claude-class, Gemini Pro) you don't have a self-hosted option. So you call the API from Spark.

Architecture

┌────────────┐    repartition(N)    ┌──────────┐
│ Spark exec │ ────────────────────►│ async    │ ──► hosted LLM API
│ (CPU only) │      ▲               │ HTTP     │     (OpenAI/Anthropic
│            │      │ rate limit    │ client   │      /Bedrock/Vertex)
└────────────┘      │ throttle      └──────────┘
                    │
            tune partition count
            ≈ RPM ÷ 60 = concurrency

4a. Synchronous calls with backoff

Lowest per-row latency, but you're at the mercy of the rate limit.

import time, random
from openai import OpenAI, RateLimitError
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType

# One client per worker, lazy-loaded (same singleton trick as Approach 1).
_client = None
def client():
    global _client
    if _client is None:
        _client = OpenAI()
    return _client

def call_with_backoff(prompt: str, attempts: int = 6) -> str:
    """Retry with exponential backoff. Capture failures as values, not exceptions."""
    for i in range(attempts):
        try:
            r = client().chat.completions.create(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": prompt}],
                max_tokens=200,
                temperature=0,
            )
            return r.choices[0].message.content
        except RateLimitError:
            time.sleep(min((2 ** i) + random.random(), 60))
        except Exception as e:
            return f"__ERROR__:{type(e).__name__}"
    return "__ERROR__:rate_limit_exhausted"

@pandas_udf(StringType())
def classify(prompts: pd.Series) -> pd.Series:
    return prompts.map(call_with_backoff)

(df
  .repartition(64)            # cap concurrency to ~64 in-flight calls
  .withColumn("label", classify("body"))
  .write.mode("overwrite").parquet("/lake/silver/reviews_classified"))

Things you'll wish you'd known on day one:

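  • repartition(N) is the simplest concurrency throttle. By Little's law, calls in flight = requests per second × seconds per call, so set N ≈ (RPM ÷ 60) × average latency in seconds, which is RPM ÷ 60 at about 1 s per call. Spark only runs as many tasks at once as it has free task slots, so N is a ceiling, not a guarantee.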
  • Capture errors as values; one rate-limit storm should not fail your whole stage.
  • Write idempotently. Spark will retry tasks, and you don't want to pay twice.
  • In real code use tenacity for the backoff loop, not hand-rolled.
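The N ≈ RPM ÷ 60 rule is Little's law with a one-second call. A sketch that keeps latency explicit; the 80% headroom default is an assumption meant to leave slack for retries and latency spikes, not a provider recommendation, and the 5,000 RPM limit in the examples is hypothetical:

```python
import math


def max_concurrency(rpm_limit: int, avg_latency_s: float, headroom: float = 0.8) -> int:
    """Little's law: calls in flight = arrival rate x time each call spends in flight.

    rpm_limit / 60 is the permitted requests per second; multiplying by the average
    latency gives how many calls can be outstanding without exceeding that rate.
    """
    return max(1, math.floor(rpm_limit / 60 * avg_latency_s * headroom))


print(max_concurrency(5_000, 1.0, headroom=1.0))  # → 83
print(max_concurrency(5_000, 2.0, headroom=1.0))  # → 166: slower calls allow more in flight
```

The same number bounds 4b: max_inflight × the number of partitions running at once should stay at or below it. A tokens-per-minute limit gives a second ceiling (use TPM ÷ tokens per request in place of RPM); take whichever is smaller.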

4b. Async fan-out within each partition

10–30× the throughput of 4a, when your rate limit allows:

import asyncio, os, aiohttp
from typing import Iterator
import pandas as pd

async def one(session, sem, prompt):
    """One in-flight HTTP call, gated by a semaphore. Failures come back as values."""
    async with sem:
        try:
            async with session.post(
                "https://api.openai.com/v1/chat/completions",
                json={"model": "gpt-4o-mini",
                      "messages": [{"role": "user", "content": prompt}],
                      "max_tokens": 200, "temperature": 0},
                timeout=aiohttp.ClientTimeout(total=60),
            ) as r:
                if r.status != 200:                 # e.g. 429 on rate limit; no retry here,
                    return f"__ERROR__:HTTP{r.status}"  # rerun __ERROR__ rows or add 4a's backoff
                data = await r.json()
                return data["choices"][0]["message"]["content"]
        except Exception as e:                      # timeouts, dropped connections
            return f"__ERROR__:{type(e).__name__}"

async def run(prompts, max_inflight=32):
    sem = asyncio.Semaphore(max_inflight)
    headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"}
    async with aiohttp.ClientSession(headers=headers) as session:
        return await asyncio.gather(*[one(session, sem, p) for p in prompts])

def classify_partition(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    """Fan out async per batch, then yield results back to Spark."""
    loop = asyncio.new_event_loop()   # one event loop per partition, reused across batches
    try:
        for batch in batches:
            results = loop.run_until_complete(run(batch["body"].tolist()))
            batch["label"] = results
            yield batch[["id", "label"]]
    finally:
        loop.close()

df.mapInPandas(classify_partition, schema="id long, label string").write.parquet("...")

Always sanity-check max_inflight × num_executors against your published rate limit, or you'll throttle yourself.

4c. Batch API — the right answer for nightly jobs

If you can wait up to 24 hours, the OpenAI, Anthropic, Gemini and Bedrock batch APIs give you a 50% discount and run against a separate rate-limit pool that doesn't compete with live traffic. Submission caps differ by provider; OpenAI's are 50,000 requests and 200 MB per input file.
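Those caps matter when a partition is large. A sketch of splitting one partition's requests into upload-sized JSONL payloads; the defaults encode OpenAI's documented 50,000-request cap and a conservative reading of 200 MB as 200,000,000 bytes, and both are parameters because other providers publish different caps:

```python
import json
from typing import Iterable, Iterator, List

MAX_REQUESTS = 50_000        # OpenAI's documented per-file request cap
MAX_BYTES = 200_000_000      # conservative reading of its 200 MB file cap


def chunk_jsonl(requests: Iterable[dict], max_requests: int = MAX_REQUESTS,
                max_bytes: int = MAX_BYTES) -> Iterator[bytes]:
    """Serialise request dicts to JSONL, starting a new payload before either cap is hit.

    A single request larger than max_bytes still gets a payload of its own;
    this sketch does not try to shrink it.
    """
    lines: List[bytes] = []
    size = 0
    for req in requests:
        line = (json.dumps(req) + "\n").encode("utf-8")
        if lines and (len(lines) >= max_requests or size + len(line) > max_bytes):
            yield b"".join(lines)
            lines, size = [], 0
        lines.append(line)
        size += len(line)
    if lines:
        yield b"".join(lines)
```

In the stage-1 sketch below, you would then call client.files.create and client.batches.create once per yielded payload rather than once per Arrow batch.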

The Spark job becomes two stages:

Stage 1 (Spark, now)            Stage 2 (Spark, +N hours)
─────────────────────           ──────────────────────────
read input                      read ledger of batch_ids
build JSONL per partition       poll client.batches.retrieve()
upload as files                 when done, download outputs
submit batch jobs               join back on custom_id
write (batch_id, ids) ledger    write enriched table

A sketch of stage 1:

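import io
import json
from typing import Iterator

import pandas as pd
from openai import OpenAI

def submit_partition(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    """Per Arrow batch: write JSONL, upload, submit a batch job, emit one ledger row."""
    client = OpenAI()  # one client per partition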
    for batch in batches:
        buf = io.BytesIO()
        for _, row in batch.iterrows():
            buf.write((json.dumps({
                "custom_id": str(row["id"]),
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": "gpt-4o-mini",
                    "messages": [{"role": "user",
                                  "content": f"Summarise: {row['body']}"}],
                    "max_tokens": 128,
                },
            }) + "\n").encode())
        buf.seek(0)
        buf.name = "requests.jsonl"          # OpenAI SDK reads this
        f = client.files.create(file=buf, purpose="batch")
        b = client.batches.create(
            input_file_id=f.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )
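        yield pd.DataFrame([{"batch_id": b.id, "n_rows": len(batch)}])

# The write triggers the submissions. Spark may retry a failed task, which
# resubmits its batches, so stage 2 should dedupe results on custom_id.
ledger = df.mapInPandas(submit_partition, schema="batch_id string, n_rows long")
ledger.write.mode("append").parquet("/lake/ops/batch_ledger")  # hypothetical ledger path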

A second scheduled Spark job picks up the ledger, polls for completion, downloads output_file_id, and joins by custom_id.
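Stage 2's core is turning the provider's output file back into rows. A sketch of that step for OpenAI, assuming the output-line shape its Batch API documents for /v1/chat/completions (a custom_id, a response holding status_code and body, and an error field); failures become __ERROR__ values, matching the errors-as-data convention from 4a:

```python
import json
from typing import Dict


def parse_batch_output(jsonl_text: str) -> Dict[str, str]:
    """Map custom_id -> completion text; failures come back as __ERROR__ values."""
    results: Dict[str, str] = {}
    for line in jsonl_text.splitlines():
        if not line.strip():
            continue                                   # tolerate blank lines
        rec = json.loads(line)
        resp = rec.get("response") or {}
        err = rec.get("error") or {}
        if err or resp.get("status_code") != 200:
            results[rec["custom_id"]] = f"__ERROR__:{err.get('code') or resp.get('status_code')}"
        else:
            body = resp["body"]
            results[rec["custom_id"]] = body["choices"][0]["message"]["content"]
    return results
```

Stage 2 would fetch the text with client.files.content(output_file_id).text, run the same parser over the error file when one is present, and join the resulting pairs back on custom_id, cast to the id's type.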

Numbers and trade-offs

Workload | Setup | Throughput | Cost (gpt-4o-mini list price)
1 M prompts, 200-token in / 60-token out, sync UDF | Approach 4a, 64 partitions | ~50 rows/sec (rate-limited) | ~$150
Same, async fan-out | Approach 4b, 32 in-flight × 64 partitions | ~800 rows/sec | ~$150
Same, Batch API | Approach 4c, finishes in ≤24h | n/a (async) | ~$75

What you get:

  • Frontier model quality.
  • Zero GPU ops on your side.
  • Batch API is cheap and politely sidesteps your live RPM cap.

What hurts:

  • Data leaves your perimeter. For regulated workloads this can be a non-starter without Bedrock / Azure OpenAI / Vertex setup.
  • Token cost can balloon — a careless prompt template multiplies your bill 10×.
  • Sync paths depend on provider availability.
  • Idempotency is on you.

Reach for the sync path when latency matters and volume is small/medium. Reach for the Batch API for any nightly enrichment, eval run, or backfill — it's almost always the right answer.


The piece most LLM-pipeline posts skip: structured output

In production, "give me a string" is rarely what you want. You want {"sentiment": "NEGATIVE", "themes": ["wifi", "breakfast"], "confidence": 0.92} — a typed object you can write straight into a Parquet column and query. A model that almost always returns valid JSON is a 4 a.m. pager.

Three ways to lock the output shape, in increasing order of safety:

1. Just ask. Cheap, unreliable. Tell the model "Reply with JSON only." It works ~95% of the time. The other 5% emit prose, leading whitespace, code fences, or invented keys. Always wrap with try/except json.JSONDecodeError and capture failures as __BAD_JSON__ rows.

2. JSON mode / response_format. OpenAI, Anthropic, Gemini all support a "guaranteed JSON" mode. The output is valid JSON, but the schema is still up to the prompt. Closes maybe 90% of the failure modes.

3. Schema-constrained decoding. This is the one to actually use. Define a Pydantic model and let the API enforce it at the token level — invalid tokens are masked out during sampling, so the output is structurally guaranteed to match.

from pydantic import BaseModel
from typing import Literal
from openai import OpenAI

class PIIRisk(BaseModel):
    has_pii: bool
    pii_types: list[Literal["email", "phone", "ssn", "card", "address"]]
    confidence: float

client = OpenAI()
result = client.chat.completions.parse(
    model="gpt-4o-mini",
    messages=[{"role": "user",
               "content": f"Extract PII info from: {row['body']}"}],
    response_format=PIIRisk,         # ← schema-constrained decoding
)
parsed: PIIRisk = result.choices[0].message.parsed

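For self-hosted models, Outlines and vLLM's guided (schema-constrained) decoding give you the same guarantee:

# Inside the vLLM Ray Data processor's preprocess step
sampling_params = dict(
    temperature=0.0,
    max_tokens=200,
    # Parameter naming has changed across vLLM / Ray releases;
    # check the SamplingParams of the version you run.
    guided_decoding=dict(json=PIIRisk.model_json_schema()),
)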

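For warehouse-native SQL, the answer is to validate after: parse the JSON, check required fields, and route malformed rows to a quarantine table. The query below uses Snowflake's TRY_PARSE_JSON; other warehouses have their own null-on-failure parsers.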

SELECT
  id,
  TRY_PARSE_JSON(ai_response) AS parsed,
  CASE WHEN TRY_PARSE_JSON(ai_response) IS NULL
       THEN 'BAD_JSON' END AS error
FROM raw_results
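The same validate-after idea also fits Python-side outputs from Approaches 1 and 4a. A stdlib-only sketch against the PIIRisk shape from above; the fence stripping targets the "Just ask" failure mode, and the 0 to 1 range on confidence is an assumption that PIIRisk itself does not enforce:

```python
import json
from typing import Optional, Tuple

ALLOWED_PII = {"email", "phone", "ssn", "card", "address"}  # mirrors PIIRisk.pii_types
FENCE = "`" * 3  # a markdown code fence, which "just ask" outputs often include


def validate_pii_json(raw: Optional[str]) -> Tuple[Optional[dict], Optional[str]]:
    """Return (parsed, None) for a well-formed PIIRisk-shaped object, else (None, error tag)."""
    if raw is None:
        return None, "__BAD_JSON__:empty"
    text = raw.strip()
    if text.startswith(FENCE):
        text = text.strip("`")              # drop the opening and closing fences
        if text.startswith("json"):
            text = text[len("json"):]       # drop the language tag after the opening fence
    try:
        obj = json.loads(text)
    except json.JSONDecodeError:
        return None, "__BAD_JSON__:parse"
    if not isinstance(obj, dict) or set(obj) != {"has_pii", "pii_types", "confidence"}:
        return None, "__BAD_JSON__:keys"    # missing or invented keys
    if not isinstance(obj["has_pii"], bool):
        return None, "__BAD_JSON__:has_pii"
    types = obj["pii_types"]
    if not isinstance(types, list) or not all(isinstance(t, str) and t in ALLOWED_PII for t in types):
        return None, "__BAD_JSON__:pii_types"
    conf = obj["confidence"]
    # bool is a subclass of int in Python, so reject it explicitly; 0..1 is an assumed range.
    if isinstance(conf, bool) or not isinstance(conf, (int, float)) or not 0 <= conf <= 1:
        return None, "__BAD_JSON__:confidence"
    return obj, None
```

Rows that come back with an error tag go to the quarantine table; the tags also line up with the __BAD_JSON__ prefix the monitoring query below counts.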

Once your output is typed, the rest of the pipeline (joins, aggregates, dashboards) gets dramatically simpler.


Observability: how you know it's working

A 99.9% success rate on 100 M rows is 100,000 wrong answers shipped to production. Three things to put in place from day one.

1. An eval set, in git. Before the first run, hand-label 100–500 rows (the ones humans can confidently judge). Re-score the model against this set every time you change the prompt, model, or temperature. If accuracy drops, you find out before your users do.

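# evals/test_pii_classifier.py
import pandas as pd

from pipeline import classify_batch  # hypothetical module: your pipeline's entry point

def test_pii_classifier_accuracy():
    eval_set = pd.read_csv("evals/pii_golden.csv")  # columns: text, expected_has_pii
    preds = pd.Series(classify_batch(eval_set["text"].tolist()))  # one has_pii bool per row
    acc = (preds == eval_set["expected_has_pii"]).mean()
    assert acc > 0.92, f"Accuracy regressed to {acc:.2%}"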

2. A 1% sample for human review. Each run, write a sample of (input, output, model_version, prompt_version) to a predictions_sample table. Have someone glance at 50 rows once a week. You'll catch silent prompt drift, a model deprecation, and the time someone changed temperature to 0.9 by accident.
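One way to pick that 1% so the choice is reproducible: hash the row id and keep a fixed slice of the hash space. A rerun of the same job then sends the same rows for review, while changing the salt (the run date, say) draws a fresh sample. A sketch; the salt parameter and the date used below are illustrative assumptions:

```python
import hashlib


def in_review_sample(row_id: str, salt: str = "", per_10k: int = 100) -> bool:
    """Keep a row when its salted hash lands in the first `per_10k` of 10,000 buckets.

    sha256 spreads ids evenly, so per_10k=100 keeps roughly 1% of rows, and the
    same (salt, row_id) pair always gets the same answer.
    """
    digest = hashlib.sha256(f"{salt}:{row_id}".encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % 10_000 < per_10k


kept = sum(in_review_sample(str(i), salt="2025-01-01") for i in range(100_000))
print(kept)  # close to 1,000
```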

3. Output-distribution monitoring. If yesterday's classifier said 12% of tickets were "high risk" and today's says 47%, something has changed even if accuracy on the eval set is fine. Track simple stats per run:

SELECT
  run_date,
  COUNT(*)                                   AS n,
  AVG(CASE WHEN risk = 'high' THEN 1 ELSE 0 END) AS pct_high,
  AVG(confidence)                            AS avg_conf,
  AVG(CASE WHEN SUBSTR(response, 1, 9) = '__ERROR__'
             OR SUBSTR(response, 1, 12) = '__BAD_JSON__'
           THEN 1 ELSE 0 END)                AS pct_failed
FROM enriched_tickets
GROUP BY run_date
ORDER BY run_date;

Throw a chart on a dashboard. Alert on > 2σ moves. This single SQL query has caught more LLM regressions than any fancy framework.
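The 2σ rule, as a sketch over one metric's history, such as pct_high per run_date. It uses the sample standard deviation of past runs and treats any change from a perfectly flat history as a move; how many past runs to include is left to you, and the history below is hypothetical:

```python
from statistics import mean, stdev
from typing import Sequence


def is_anomalous(history: Sequence[float], today: float, k: float = 2.0) -> bool:
    """True when today's value sits more than k sample standard deviations from the history mean."""
    if len(history) < 2:
        return False              # not enough runs to estimate spread yet
    mu, sigma = mean(history), stdev(history)
    if sigma == 0:
        return today != mu        # perfectly flat history: any change is a move
    return abs(today - mu) > k * sigma


pct_high = [0.11, 0.12, 0.13, 0.12, 0.12]  # hypothetical share of "high risk" per run
print(is_anomalous(pct_high, 0.47), is_anomalous(pct_high, 0.125))  # → True False
```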


Decision framework

In rough order of "what to try first":

Situation | Start with
Data lives in Databricks / Snowflake / BigQuery, simple "enrich a column" | Approach 3
Frontier model required, can wait hours, offline job | Approach 4c (Batch API)
Frontier model required, need same-hour results | Approach 4a or 4b
Self-host required, small/medium volume, ≤13 B model | Approach 1
Self-host required, very high volume (>10 M rows) | Approach 2

A real pipeline often combines these: a cheap embedding-based filter via predict_batch_udf to cut the row count, then a frontier model via Batch API on the survivors, with a warehouse-native function for the easy cases.

Rules of thumb regardless of approach

  • Load the model once per worker. Module-level singleton or predict_batch_udf's make_predict_fn.
  • Set spark.sql.execution.arrow.maxRecordsPerBatch deliberately. The default of 10k is rarely right for LLMs.
  • Make the job idempotent. Hash the prompt, write to a deterministic path, or maintain a ledger.
  • Capture errors as data, not exceptions. __ERROR__:RateLimit in a column beats a stack trace in driver logs.
  • Cap concurrency to your rate limit. df.repartition(N) for sync; semaphore for async.
  • Lock the output shape. Pydantic + response_format for hosted, Outlines/guided_json for self-hosted, post-hoc validation for SQL.
  • Ship an eval set, a 1% sample, and a distribution dashboard before you ship the pipeline.
  • Cache aggressively. (input_hash, model, prompt_version, output) makes reruns free joins.
  • Prefer Batch APIs for offline work. The 50% discount and separate quota pool change the economics of "rerun the table next week."
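A cache key along the lines of that (input_hash, model, prompt_version) tuple, as a sketch: serialising with sorted keys makes the encoding independent of dict ordering, so the same logical request hashes identically across runs and machines. The params argument is an assumption, added because a temperature change should also miss the cache:

```python
import hashlib
import json
from typing import Optional


def cache_key(input_text: str, model: str, prompt_version: str,
              params: Optional[dict] = None) -> str:
    """Deterministic sha256 over (input, model, prompt_version, sampling params)."""
    payload = json.dumps(
        {"input": input_text, "model": model,
         "prompt_version": prompt_version, "params": params or {}},
        sort_keys=True, ensure_ascii=False,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


k1 = cache_key("card ending 4242", "gpt-4o-mini", "v3", {"temperature": 0, "max_tokens": 200})
k2 = cache_key("card ending 4242", "gpt-4o-mini", "v3", {"max_tokens": 200, "temperature": 0})
print(k1 == k2, len(k1))  # → True 64
```

Left-join incoming rows to the cache table on this key and send only the misses to the model; that is what turns a rerun into a join.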

TL;DR, again

Spark is the orchestrator, the LLM is a function, and the four patterns are four places that function can live: on the Spark executor, on a Ray-managed GPU pool, inside the warehouse, or behind an HTTP API. Pick the one that matches your data gravity and your budget; lock the output schema; ship the eval set with the pipeline. The hard part isn't getting the model to talk. It's making the talking auditable, idempotent, and affordable at a billion rows.

