Any way to streaming-preprocess a dataset to disk?

Hello!
I’m training a rather small image classification model (1M params) on a rather large HF dataset (~100 GB train split, Andron00e/Places365-custom on Hugging Face). Right now, the bottleneck on my system is the transform stack, which applies image augmentations and is CPU-bound. I want to “remux” the dataset to a copy with the augmentations and other preprocessing steps baked in, so that preprocessing is a one-time cost and further training runs will be bottlenecked by compute or disk speed instead of by my paltry number of CPU cores.

I don’t see a way to do this in a streaming fashion – I could of course do Dataset.map and save_to_disk, but I believe this requires me to have enough RAM to load the entire split before saving it.

I want to be able to define a per-batch augmentation function and point it at the existing dataset ID and a target dataset path, and it iterates through the dataset at whatever speed the CPU can maintain, preprocesses each batch, and appends it to the target parquet file on the fly. This way I can ‘remux’ a multi-hundred-GB dataset as long as I have enough disk space.

1 Like

If the order of the data isn’t particularly important, I think there are several possible methods, such as Streaming.

I appreciate the response, but I do not particularly trust what GPT et al have to say about HF’s libraries as they’re relatively new and the API has not historically been stable. I wanted to see if someone has solved this problem before.

1 Like

Underfitting is going to become your biggest problem. It’s only a tiny model. Nevertheless, here’s a script my bot wrote for you.

# path: tools/stream_remux_hf_to_parquet.py
"""
Stream-remux a Hugging Face dataset to Parquet shards with on-CPU augmentations.

Why: avoid .map() materializing full splits; keep CPU-bound transforms as a one-time cost.
"""

import argparse
import io
import json
import math
import os
import signal
import sys
from dataclasses import dataclass
from functools import partial
from multiprocessing import Pool
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple

import pyarrow as pa
import pyarrow.parquet as pq

from datasets import load_dataset, IterableDataset

try:
    from PIL import Image, ImageOps, ImageFilter
except Exception as exc:  # pragma: no cover
    raise RuntimeError("Pillow is required: pip install pillow") from exc


# --------------------------- Args & Config ---------------------------

@dataclass
class Config:
    """Run settings for the streaming remux; built from the CLI by parse_args()."""

    dataset: str  # HF dataset ID or local path (--dataset)
    subset: Optional[str]  # dataset config/subset name, or None (--subset)
    split: str  # split name passed to load_dataset (--split)
    out_dir: str  # directory that receives the Parquet shards and manifest (--out-dir)
    batch_size: int  # number of examples grouped per processing batch (--batch-size)
    rows_per_shard: int  # shard is rotated once it holds at least this many rows
    num_workers: int  # >1 enables a multiprocessing pool; otherwise work runs in-process
    seed: int  # base seed mixed with the example index for per-sample pseudo-randomness
    image_col: str  # name of the input column holding the image payload
    label_col: Optional[str]  # optional input column copied to the output "label" field
    extra_cols: List[str]  # additional input columns copied through unchanged
    image_format: str  # output encoding, "jpeg" or "png" (--image-format)
    jpeg_quality: int  # JPEG quality used when image_format == "jpeg"
    resume: bool  # --resume flag: continue into an out_dir that already has partial shards


def parse_args(argv: Optional[List[str]] = None) -> Config:
    """Parse command-line options into a Config.

    Args:
        argv: Argument list to parse; None makes argparse read sys.argv[1:].

    Returns:
        A Config populated from the parsed options. ``extra_cols`` is the
        comma-separated ``--extra-cols`` value split into individual names,
        with surrounding whitespace removed and empty entries dropped.
    """
    p = argparse.ArgumentParser(prog="stream-remux",
                                description="Stream a HF dataset, preprocess, and write Parquet shards.")
    p.add_argument("--dataset", required=True, help="HF dataset ID or local path.")
    p.add_argument("--subset", default=None, help="Dataset config/subset if applicable.")
    p.add_argument("--split", default="train", help="Split name (supports HF split syntax).")
    p.add_argument("--out-dir", required=True, help="Output directory.")
    p.add_argument("--batch-size", type=int, default=256, help="Streaming CPU batch size.")
    p.add_argument("--rows-per-shard", type=int, default=50_000, help="Rotate shard after N rows.")
    p.add_argument("--num-workers", type=int, default=max(1, os.cpu_count() or 1), help="Multiprocessing workers.")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--image-col", default="image", help="Name of image column (HF Image or bytes/path).")
    p.add_argument("--label-col", default=None, help="Optional label column to preserve.")
    p.add_argument("--extra-cols", default="", help="Comma-separated extra column names to preserve.")
    p.add_argument("--image-format", choices=["jpeg", "png"], default="jpeg")
    p.add_argument("--jpeg-quality", type=int, default=90)
    p.add_argument("--resume", action="store_true", help="Resume if out-dir exists with partial shards.")
    ns = p.parse_args(argv)

    # Strip each name: "--extra-cols 'a, b'" must yield ["a", "b"], not
    # ["a", " b"], otherwise the padded name never matches a dataset column.
    extra = [name.strip() for name in (ns.extra_cols or "").split(",") if name.strip()]

    return Config(
        dataset=ns.dataset,
        subset=ns.subset,
        split=ns.split,
        out_dir=ns.out_dir,
        batch_size=ns.batch_size,
        rows_per_shard=ns.rows_per_shard,
        num_workers=ns.num_workers,
        seed=ns.seed,
        image_col=ns.image_col,
        label_col=ns.label_col,
        extra_cols=extra,
        image_format=ns.image_format,
        jpeg_quality=ns.jpeg_quality,
        resume=ns.resume,
    )


# --------------------------- Augmentation ---------------------------

def _rng(seed: int, i: int) -> int:
    # Why: cheap per-sample randomness without global state.
    return (seed * 0x9E3779B1 + i) & 0xFFFFFFFF


def _load_image(x: Any) -> Image.Image:
    """Decode an image payload into an RGB PIL image.

    Accepted payloads:
      * a dict carrying "bytes" and/or "path" (the undecoded HF Image form);
      * raw encoded bytes (bytes, bytearray or memoryview);
      * a filesystem path (str or os.PathLike) that exists;
      * an object that already exposes ``convert`` (e.g. a PIL image).

    Raises:
        ValueError: if the payload matches none of the forms above.
    """
    if isinstance(x, dict):
        # Either field of the dict may be None, so prefer the in-memory bytes
        # and fall back to the path instead of feeding None to BytesIO.
        raw = x.get("bytes")
        x = raw if raw is not None else x.get("path")
    if isinstance(x, (bytes, bytearray, memoryview)):
        return Image.open(io.BytesIO(x)).convert("RGB")
    if isinstance(x, (str, os.PathLike)) and os.path.exists(x):
        return Image.open(x).convert("RGB")
    if hasattr(x, "convert"):  # already PIL
        return x.convert("RGB")
    raise ValueError("Unsupported image payload for image column")


def _jpeg_bytes(img: Image.Image, quality: int) -> bytes:
    """Encode ``img`` as an optimized JPEG at ``quality`` and return the encoded bytes."""
    with io.BytesIO() as sink:
        img.save(sink, format="JPEG", quality=quality, optimize=True)
        return sink.getvalue()


def _png_bytes(img: Image.Image) -> bytes:
    """Encode ``img`` as an optimized PNG and return the encoded bytes."""
    with io.BytesIO() as sink:
        img.save(sink, format="PNG", optimize=True)
        return sink.getvalue()


def augment_record(
    example: Dict[str, Any],
    i: int,
    cfg: Config,
) -> Dict[str, Any]:
    """
    Augment one example and return an Arrow-friendly record.

    This is the function to customize for your own pipeline. As written it
    takes a centred crop whose size is a per-sample pseudo-random fraction
    (0.7 to 1.0) of the source, resizes it to 256x256, mirrors it for some
    samples, sharpens it, and encodes it as JPEG or PNG.

    Args:
        example: One dataset row; must contain ``cfg.image_col``.
        i: Index of the example, used to derive its pseudo-random values.
        cfg: Run configuration.

    Returns:
        A dict with "image_bytes", "height", "width", "format", plus "label"
        when ``cfg.label_col`` is set and present, plus any ``cfg.extra_cols``
        found in the example.
    """
    source = _load_image(example[cfg.image_col])
    src_w, src_h = source.size

    # Per-sample crop fraction in [0.7, 1.0), derived from (seed, index).
    rr = (_rng(cfg.seed, i) % 1000) / 1000.0
    scale = 0.7 + 0.3 * rr
    crop_w = max(8, int(src_w * scale))
    crop_h = max(8, int(src_h * scale))

    # The crop window is centred; offsets are clamped so they never go negative.
    x0 = max(0, (src_w - crop_w) // 2)
    y0 = max(0, (src_h - crop_h) // 2)
    box = (x0, y0, x0 + crop_w, y0 + crop_h)
    result = source.crop(box).resize((256, 256), Image.BICUBIC)

    # A second, differently-seeded draw decides the horizontal flip.
    if _rng(cfg.seed ^ 0xABCDEF, i) & 1:
        result = ImageOps.mirror(result)
    result = result.filter(ImageFilter.SHARPEN)

    encoded = (
        _jpeg_bytes(result, cfg.jpeg_quality)
        if cfg.image_format == "jpeg"
        else _png_bytes(result)
    )

    record: Dict[str, Any] = {
        "image_bytes": encoded,
        "height": 256,
        "width": 256,
        "format": cfg.image_format,
    }
    if cfg.label_col and cfg.label_col in example:
        label = example[cfg.label_col]
        record["label"] = int(label) if isinstance(label, int) else label
    record.update({name: example[name] for name in cfg.extra_cols if name in example})
    return record


# --------------------------- Batch Helpers ---------------------------

def batch_iter(it: Iterable[Dict[str, Any]], batch_size: int) -> Iterator[List[Dict[str, Any]]]:
    """Yield consecutive lists of up to ``batch_size`` items drawn from ``it``.

    Every batch except possibly the last has ``batch_size`` items; a trailing
    partial batch is yielded as-is and an empty input yields nothing.
    """
    source = iter(it)
    while True:
        chunk: List[Dict[str, Any]] = []
        for item in source:
            chunk.append(item)
            if len(chunk) >= batch_size:
                break
        if not chunk:
            # Source exhausted with nothing pending.
            return
        yield chunk


def process_batch(
    batch: List[Dict[str, Any]],
    start_index: int,
    cfg: Config,
) -> List[Dict[str, Any]]:
    """Augment every example of ``batch`` in order, in the current process.

    The example at position ``j`` is passed index ``start_index + j``.
    """
    return [
        augment_record(example, start_index + offset, cfg)
        for offset, example in enumerate(batch)
    ]


# --------------------------- Writer ---------------------------

class ShardedParquetWriter:
    """Write record batches to numbered Parquet shards in ``out_dir``.

    Each shard is written to ``part-NNNNN.parquet.tmp`` and renamed to
    ``part-NNNNN.parquet`` only after its writer has been closed, so a file
    with the final name is always a complete Parquet file. A new shard is
    started once the current one holds at least ``rows_per_shard`` rows.

    The schema is inferred from the first non-empty batch and reused for every
    later batch and shard.
    """

    def __init__(self, out_dir: str, rows_per_shard: int) -> None:
        """Create ``out_dir`` if needed and continue numbering after existing shards."""
        os.makedirs(out_dir, exist_ok=True)
        self.out_dir = out_dir
        self.rows_per_shard = rows_per_shard
        self.writer: Optional[pq.ParquetWriter] = None
        self.schema: Optional[pa.Schema] = None
        self.rows_in_shard = 0
        self.shard_idx = self._detect_resume_index()

    def _detect_resume_index(self) -> int:
        """Return one past the highest ``part-<int>.parquet`` index in ``out_dir`` (0 if none)."""
        prefix, suffix = "part-", ".parquet"
        indices = []
        for name in os.listdir(self.out_dir):
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            try:
                indices.append(int(name[len(prefix):-len(suffix)]))
            except ValueError:
                # Not one of our numbered shards; ignore it.
                continue
        return max(indices) + 1 if indices else 0

    def _next_path(self) -> str:
        """Return the final path of the shard currently being written."""
        return os.path.join(self.out_dir, f"part-{self.shard_idx:05d}.parquet")

    def _finalize_shard(self) -> None:
        """Close the open writer, if any, then publish its temp file under the final name."""
        if self.writer is None:
            return
        final_path = self._next_path()
        tmp_path = final_path + ".tmp"
        # Close first: the Parquet footer is only written on close, and an
        # open file cannot be renamed on every platform.
        self.writer.close()
        self.writer = None
        if os.path.exists(tmp_path):
            os.replace(tmp_path, final_path)

    def _rotate(self) -> None:
        """Finish the current shard and advance to the next shard index."""
        self._finalize_shard()
        self.rows_in_shard = 0
        self.shard_idx += 1

    def write_records(self, records: List[Dict[str, Any]]) -> None:
        """Append ``records`` to the current shard, rotating when it is full.

        An empty list is a no-op.
        """
        if not records:
            return
        table = pa.Table.from_pylist(records, schema=self.schema)
        if self.schema is None:
            self.schema = table.schema
        if self.writer is None:
            # Open a writer whenever none is active. Keying this on the schema
            # being unset would only ever open the first shard and fail on the
            # first batch after a rotation.
            self.writer = pq.ParquetWriter(self._next_path() + ".tmp", self.schema, compression="zstd")
        self.writer.write_table(table)
        self.rows_in_shard += table.num_rows

        if self.rows_in_shard >= self.rows_per_shard:
            self._rotate()

    def close(self) -> None:
        """Finish and publish the shard in progress, if any. Safe to call repeatedly."""
        self._finalize_shard()


# --------------------------- Main ---------------------------

def install_sigint_handler() -> None:
    """Make Ctrl+C raise KeyboardInterrupt in the main thread.

    Why: ensure clean close on Ctrl+C. ``signal.SIG_DFL`` would let the OS
    terminate the process immediately, skipping ``finally`` blocks; Python's
    own ``default_int_handler`` raises KeyboardInterrupt instead, so cleanup
    code gets a chance to run.
    """
    signal.signal(signal.SIGINT, signal.default_int_handler)


def main(argv: Optional[List[str]] = None) -> int:
    """Stream the dataset, augment it batch by batch, and write Parquet shards.

    Steps:
      1. Parse the CLI options and open the dataset in streaming mode.
      2. Drop columns that are not needed (when the stream exposes its features).
      3. With ``--resume``, count the rows in the finished ``part-*.parquet``
         shards already in the output directory and skip that many examples,
         so a rerun continues instead of writing the same data again.
      4. Augment each batch (in a worker pool when ``num_workers > 1``) and
         append it to the sharded writer.
      5. Write ``manifest.json`` describing the run.

    Args:
        argv: Argument list; None makes argparse read sys.argv[1:].

    Returns:
        0 on success.
    """
    cfg = parse_args(argv)
    install_sigint_handler()
    os.makedirs(cfg.out_dir, exist_ok=True)

    # Finished shards from a previous run (temp files are not counted).
    existing_shards = sorted(
        name for name in os.listdir(cfg.out_dir)
        if name.startswith("part-") and name.endswith(".parquet")
    )
    already_written = 0
    if existing_shards and cfg.resume:
        already_written = sum(
            pq.read_metadata(os.path.join(cfg.out_dir, name)).num_rows
            for name in existing_shards
        )
        sys.stderr.write(
            f"Resuming: {already_written:,} examples found in "
            f"{len(existing_shards)} existing shard(s); skipping them.\n"
        )
    elif existing_shards:
        sys.stderr.write(
            f"Warning: {cfg.out_dir} already contains {len(existing_shards)} shard(s); "
            "new shards will be added after them starting from the beginning of the "
            "dataset. Pass --resume to continue where the previous run stopped.\n"
        )

    # HF streaming loader
    ds_kwargs = dict(split=cfg.split, streaming=True)
    if cfg.subset:
        stream: IterableDataset = load_dataset(cfg.dataset, cfg.subset, **ds_kwargs)  # type: ignore
    else:
        stream = load_dataset(cfg.dataset, **ds_kwargs)  # type: ignore

    # Column projection early to reduce payload. A streaming dataset may not
    # know its features up front; in that case keep every column, since
    # augment_record only reads the ones it needs anyway.
    keep_cols = {cfg.image_col}
    if cfg.label_col:
        keep_cols.add(cfg.label_col)
    keep_cols.update(c for c in cfg.extra_cols if c)
    features = stream.features
    if features is not None:
        drop_cols = [c for c in features if c not in keep_cols]
        if drop_cols:
            stream = stream.remove_columns(drop_cols)  # type: ignore

    if already_written:
        stream = stream.skip(already_written)  # type: ignore

    writer = ShardedParquetWriter(cfg.out_dir, cfg.rows_per_shard)

    pool: Optional[Pool] = None
    if cfg.num_workers > 1:
        pool = Pool(processes=cfg.num_workers)
        map_fn = partial(_map_with_pool, pool=pool, cfg=cfg)
    else:
        map_fn = partial(_map_sync, cfg=cfg)

    # ``total`` doubles as the absolute index of the next example, so the
    # per-sample pseudo-random values stay the same across a resume.
    total = already_written
    batches_done = 0
    succeeded = False
    try:
        for batch in batch_iter(stream, cfg.batch_size):
            processed = map_fn(batch, total)
            writer.write_records(processed)
            total += len(batch)
            batches_done += 1
            # Counting batches (rather than total % (batch_size * 20)) keeps
            # progress output working after a resume offset.
            if batches_done % 20 == 0:
                sys.stderr.write(f"\rWrote {total:,} examples...")
                sys.stderr.flush()
        sys.stderr.write(f"\nDone. Total examples: {total:,}\n")
        succeeded = True
    finally:
        try:
            writer.close()
        finally:
            if pool:
                # On failure or Ctrl+C, stop the workers outright instead of
                # waiting for them to drain.
                if succeeded:
                    pool.close()
                else:
                    pool.terminate()
                pool.join()

    # Write a tiny manifest
    manifest = {
        "dataset": cfg.dataset,
        "subset": cfg.subset,
        "split": cfg.split,
        "rows_per_shard": cfg.rows_per_shard,
        "total_examples": total,
        "image_format": cfg.image_format,
        "image_col": cfg.image_col,
        "label_col": cfg.label_col,
        "extra_cols": cfg.extra_cols,
    }
    with open(os.path.join(cfg.out_dir, "manifest.json"), "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)
    return 0


def _map_with_pool(batch: List[Dict[str, Any]], start_idx: int, pool: Pool, cfg: Config) -> List[Dict[str, Any]]:
    """Augment ``batch`` across the worker pool, preserving input order.

    Each example is paired with its absolute index (``start_idx`` plus its
    position in the batch) so the per-sample pseudo-random values do not
    depend on which worker handles it.
    """
    worker = partial(augment_record, cfg=cfg)
    indices = range(start_idx, start_idx + len(batch))
    jobs = list(zip(batch, indices))
    return pool.starmap(worker, jobs)


def _map_sync(batch: List[Dict[str, Any]], start_idx: int, cfg: Config) -> List[Dict[str, Any]]:
    """Augment ``batch`` in the current process; the single-worker counterpart of the pool path."""
    results = process_batch(batch, start_index=start_idx, cfg=cfg)
    return results


if __name__ == "__main__":
    # Script entry point: propagate main()'s return value as the process exit status.
    sys.exit(main())

Usage:

pip install datasets pillow pyarrow
python tools/stream_remux_hf_to_parquet.py \
  --dataset Andron00e/Places365-custom \
  --split train \
  --out-dir ./places365-remux \
  --image-col image --label-col label \
  --batch-size 256 --rows-per-shard 20000 \
  --num-workers 8 --image-format jpeg --jpeg-quality 90

1 Like

Once again, I’ve read what chatgpt has to say on the topic. I was more curious to see if anyone has actual experience with this specific problem

ChatGPT. No you have my words and then you have an engineering models bullshit free script. That’s what you have. If you think it’s anything else then you might want to consider therapy. I won’t be wasting anymore time assisting you.

Maybe something like this ?

# Sketch: open the dataset in streaming mode, attach the preprocessing lazily
# with map(), then write one Parquet file per shard of the stream.
# The "..." placeholders stand for your dataset arguments and map function.
ds = load_dataset(..., streaming=True)
ds = ds.map(...)

# NOTE(review): num_shards, shard() and to_parquet() on a streaming dataset
# depend on the installed `datasets` version; confirm they are available.
num_shards = ds.num_shards
for shard_idx in range(num_shards):  # you can parallelize this code
    # Each iteration writes the shard_idx-th slice to its own file, named "<index>-of-<count>.parquet".
    ds.shard(num_shards=num_shards, index=shard_idx).to_parquet(f"{shard_idx:05d}-of-{num_shards:05d}.parquet")
1 Like

You do not need enough RAM to hold a 100 GB split to do this with datasets. Dataset.map() processes in chunks and writes intermediate Arrow cache files to disk, and the dataset itself is memory mapped from disk.

Option A: The simplest “remux” is map plus save_to_disk, with bounded memory

Tune writer_batch_size to control peak RAM during map. Smaller means less temporary memory.

import os
import datasets
from datasets import load_dataset

# Option A: non-streaming load, chunked map(), then save_to_disk().

datasets.config.IN_MEMORY_MAX_SIZE = 0  # avoid accidental in-memory copies

ds = load_dataset("Andron00e/Places365-custom", split="train")

def preprocess(batch):
    """Placeholder batched transform: receives and returns a dict of columns."""
    # do your deterministic preprocessing here, for example decode, resize, center crop, normalize
    # return something like {"pixel_values": ...} or a new encoded image column
    return batch

ds2 = ds.map(
    preprocess,
    batched=True,
    batch_size=64,
    # os.cpu_count() may return None; fall back to a single process then
    # instead of letting min() raise TypeError.
    num_proc=min(4, os.cpu_count() or 1),
    writer_batch_size=64,   # key knob for memory
    keep_in_memory=False,
)

ds2.save_to_disk("/path/to/places365_preprocessed")

Note: if your augmentations are random, baking them once removes that randomness for future epochs. Many people only bake deterministic steps (decode, resize, normalize) and keep light stochastic aug on GPU.

Option B: True streaming remux to sharded Parquet

If you want a strict stream and append style pipeline, load in streaming mode and write your own Parquet shards with PyArrow. HF does not currently let you directly save_to_disk() an IterableDataset without converting first.

Sketch:

import os

from datasets import load_dataset
import pyarrow as pa
import pyarrow.parquet as pq

# Option B: stream the dataset and append processed batches to Parquet
# shards, starting a new shard every 200,000 rows.

ds = load_dataset("Andron00e/Places365-custom", split="train", streaming=True)

def preprocess_batch(examples):
    """Placeholder batched transform for one streamed batch."""
    # examples is a dict of lists
    # return dict of lists with processed tensors or encoded bytes
    return examples

# The output directory must exist before the first shard is opened.
os.makedirs("shards", exist_ok=True)

writer = None
rows_written = 0
shard_id = 0

try:
    for batch in ds.iter(batch_size=64):
        out = preprocess_batch(batch)
        table = pa.Table.from_pydict(out)

        if writer is None:
            # The schema of each shard is taken from its first batch.
            writer = pq.ParquetWriter(f"shards/train-{shard_id:05d}.parquet", table.schema, compression="zstd")

        writer.write_table(table)
        rows_written += table.num_rows

        if rows_written >= 200_000:  # rotate shard
            writer.close()
            writer = None
            rows_written = 0
            shard_id += 1
finally:
    # Close the shard in progress even if iteration or preprocessing fails,
    # so the file gets its Parquet footer.
    if writer is not None:
        writer.close()

Then upload the shards/ folder to the Hub using upload_large_folder or git-lfs.

Practical advice for images

  • If you store fully augmented pixel_values arrays, the remuxed dataset can get much larger than the original.

  • If your goal is speed, consider baking decode + resize + normalization, and keep random aug minimal, or move aug to GPU.

If you share what you are baking (resize only, random crop, color jitter, RandAugment, etc.) and what format you want to store (encoded JPEG bytes vs float tensors), I can suggest the best schema so the remuxed dataset stays fast without exploding in size.

1 Like