UCPE Causal-Forcing Checkpoints

Wan2.2-TI2V-5B + UCPE camera-control checkpoints for the causal video-generation pipeline at github.com/weijielyu/RayStream_CF (cf_ucpe repo).

All checkpoints are at 704×1280 (720p), 81 frames @ 16 fps, TI2V with UCPE camera conditioning (relray_absmap, attn_compress=8, parallel cam_self_attn branches at every DiT block).

Repository layout

.
├── README.md                                 # this file
├── wan22_bidirectional_ucpe/                 # Wan2.2 bidirectional teacher (DeepSpeed ckpt, ~24 GB)
│   ├── checkpoint/
│   │   ├── mp_rank_00_model_states.pt        # ← actual weights (21 GB)
│   │   └── bf16_zero_pp_rank_*.pt            # optimizer shards (8 × 213 MB)
│   ├── latest
│   └── zero_to_fp32.py
│
├── ode_regression_wan21_sf/                  # Stage-1: causal student after DF-style ODE regression
│   ├── checkpoint_model_000400/model.pt      # 400 steps  (~20 GB)
│   └── checkpoint_model_001000/model.pt      # 1000 steps (~20 GB)
│
├── dmd_unfreeze_cam_wan21_sf/                # Stage-2 variant A: DMD with camera branch trainable (lr_cam=10x)
│   ├── checkpoint_model_000500/model.pt      # 500 steps  (~135 GB, full-resume bundle)
│   └── checkpoint_model_001000/model.pt      # 1000 steps (~135 GB)
│
└── dmd_freeze_cam_wan21_sf/                  # Stage-2 variant B: DMD with camera branch frozen
    ├── checkpoint_model_000500/model.pt      # 500 steps  (~132 GB)
    └── checkpoint_model_001000/model.pt      # 1000 steps (~141 GB)

cf_ucpe checkpoint format (top-level keys of each model.pt):

  • ode_regression_wan21_sf/*/model.pt: generator
  • dmd_*_wan21_sf/*/model.pt: generator, generator_ema, fake_score, generator_optimizer, critic_optimizer, step

For inference you only need generator_ema (DMD) or generator (ODE); see scripts/extract_ema_ckpt.py in the code repo to slim the checkpoints down.
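A minimal sketch of that slimming step, assuming only the key layout listed above. inference_subset and slim_checkpoint are hypothetical names and this is not the code of scripts/extract_ema_ckpt.py; the PyTorch calls are the standard torch.load/torch.save.

```python
from pathlib import Path


def inference_subset(ckpt: dict) -> dict:
    """Keep only the weights inference needs, under their original key.

    DMD bundles -> {"generator_ema": ...}; ODE checkpoints -> {"generator": ...}.
    Keeping the key name assumes the loader looks the weights up by that key.
    """
    if "generator_ema" in ckpt:
        return {"generator_ema": ckpt["generator_ema"]}
    if "generator" in ckpt:
        return {"generator": ckpt["generator"]}
    raise KeyError(f"no generator weights among keys {sorted(ckpt)}")


def slim_checkpoint(src: str, dst: str) -> None:
    """Hypothetical helper: load a full bundle on CPU and save only the subset."""
    import torch  # lazy import so inference_subset() is usable without torch

    ckpt = torch.load(src, map_location="cpu")  # needs enough RAM for the whole bundle
    Path(dst).parent.mkdir(parents=True, exist_ok=True)
    torch.save(inference_subset(ckpt), dst)
```

Loading a ~135 GB DMD bundle needs comparable host RAM; on PyTorch 2.1+ passing mmap=True to torch.load may lower the peak for zipfile-format files.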

The Wan2.2 bidirectional checkpoint is in DeepSpeed ZeRO-3 layout. Code that loads it (e.g. UCPE/scripts/predict_one_sample.py) reads the weights directly from checkpoint/mp_rank_00_model_states.pt, so no zero_to_fp32.py conversion is needed.
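A sketch of loading the teacher weights straight from that file, assuming the standard DeepSpeed engine format: the `latest` file names the checkpoint tag, and the model-states file keeps the state dict under the "module" key. The helper names are hypothetical, and UCPE's own loader may post-process keys differently.

```python
from pathlib import Path


def deepspeed_model_states_path(ckpt_dir: str) -> Path:
    """Locate mp_rank_00_model_states.pt inside a DeepSpeed checkpoint directory.

    Tries the tag recorded in `latest` first, then the folder name "checkpoint"
    used in wan22_bidirectional_ucpe/ (in case the folder was renamed on upload).
    """
    root = Path(ckpt_dir)
    candidates = []
    latest = root / "latest"
    if latest.is_file():
        candidates.append(latest.read_text().strip())
    candidates.append("checkpoint")
    for tag in candidates:
        path = root / tag / "mp_rank_00_model_states.pt"
        if path.is_file():
            return path
    raise FileNotFoundError(f"no mp_rank_00_model_states.pt under {root}")


def load_teacher_state_dict(ckpt_dir: str) -> dict:
    """Return the model weights, which DeepSpeed stores under the "module" key."""
    import torch  # lazy import

    # weights_only=False: DeepSpeed files carry non-tensor metadata; trusted files only.
    states = torch.load(deepspeed_model_states_path(ckpt_dir),
                        map_location="cpu", weights_only=False)
    return states["module"]
```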


Quick start: download

huggingface-cli download wlyu/ucpe_checkpoints --local-dir ./ucpe_checkpoints

Or pull a specific subfolder:

huggingface-cli download wlyu/ucpe_checkpoints \
    --include 'dmd_unfreeze_cam_wan21_sf/checkpoint_model_001000/*' \
    --local-dir ./ucpe_checkpoints
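The same selective download from Python, as a sketch using huggingface_hub.snapshot_download with allow_patterns (the Python counterpart of --include). checkpoint_pattern and download are hypothetical helper names; the zero-padded step format mirrors the folder names in the layout above.

```python
REPO_ID = "wlyu/ucpe_checkpoints"


def checkpoint_pattern(variant: str, step: int) -> str:
    """Glob for one checkpoint folder, e.g. dmd_unfreeze_cam_wan21_sf at step 1000."""
    return f"{variant}/checkpoint_model_{step:06d}/*"


def download(variant: str, step: int, local_dir: str = "./ucpe_checkpoints") -> str:
    """Fetch a single checkpoint folder; returns the local snapshot path."""
    from huggingface_hub import snapshot_download  # lazy import

    return snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=[checkpoint_pattern(variant, step)],
        local_dir=local_dir,
    )
```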

Training (in cf_ucpe)

The full pipeline is Wan2.2 bidirectional → ODE regression (causal student) → DMD distillation.

Stage 1: ODE regression (DF mode, matches upstream Self-Forcing)

Each training step is a single causal forward pass with an independently sampled timestep per block and no clean_x input. Enable it with use_df: true in the config, which dispatches to model.ode_regression_df.ODERegressionDF.
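To make "per-block random timesteps" concrete: frames are grouped into causal blocks and each block gets its own noise level, so one forward pass sees several noise levels at once. The sketch assumes 21 latent frames, 3 frames per block and a 4-entry denoising step list; all three values are illustrative rather than read from the config, and per_block_timesteps is a hypothetical function, not ODERegressionDF's code.

```python
import random


def per_block_timesteps(num_frames: int, frames_per_block: int,
                        denoising_steps: list[int], rng: random.Random) -> list[int]:
    """Sample one timestep per causal block and repeat it for each frame in that block."""
    if num_frames % frames_per_block:
        raise ValueError("num_frames must be a multiple of frames_per_block")
    timesteps: list[int] = []
    for _ in range(num_frames // frames_per_block):
        t = rng.choice(denoising_steps)  # independent draw for every block
        timesteps.extend([t] * frames_per_block)
    return timesteps


# Illustrative values only: 21 latent frames, 3 per block, 4-step schedule.
print(per_block_timesteps(21, 3, [1000, 750, 500, 250], random.Random(0)))
```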

Run on each of 4 nodes (set NODE_RANK=0..3):

LOG_DIR=output/ucpe_training_720_v2/ode_regression_wan21_sf \
  CONFIG=configs/ucpe_ode_regression_720_wan21_sf.yaml \
  NODE_RANK=0 MASTER_PORT=36903 MASTER_ADDR=<node0-ip> \
  bash scripts/run_ode_regression_720_multinode.sh
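If the launch is driven from Python (for example from a cluster job wrapper), the per-node environment can be built as below. node_env is a hypothetical helper whose defaults are copied from the command above; all nodes share MASTER_ADDR and MASTER_PORT and differ only in NODE_RANK.

```python
import os


def node_env(node_rank: int, master_addr: str, *, num_nodes: int = 4,
             master_port: int = 36903,
             log_dir: str = "output/ucpe_training_720_v2/ode_regression_wan21_sf",
             config: str = "configs/ucpe_ode_regression_720_wan21_sf.yaml") -> dict:
    """Environment for scripts/run_ode_regression_720_multinode.sh on one node."""
    if not 0 <= node_rank < num_nodes:
        raise ValueError(f"NODE_RANK must be in 0..{num_nodes - 1}")
    env = dict(os.environ)
    env.update(LOG_DIR=log_dir, CONFIG=config, NODE_RANK=str(node_rank),
               MASTER_PORT=str(master_port), MASTER_ADDR=master_addr)
    return env

# On node k:
# subprocess.run(["bash", "scripts/run_ode_regression_720_multinode.sh"],
#                env=node_env(k, "<node0-ip>"), check=True)
```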

Checkpoints are saved every 200 steps; about 1000 steps in total is enough.
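To pick up the newest checkpoint of a run, a small sketch that scans LOG_DIR for the checkpoint_model_NNNNNN/model.pt folders seen in the layout above. It assumes checkpoints land directly under LOG_DIR, which matches the published ode_regression_wan21_sf/ tree but is not confirmed for the training script; latest_checkpoint is a hypothetical helper.

```python
import re
from pathlib import Path
from typing import Optional

# Folder names follow checkpoint_model_<6-digit step>, as in the layout above.
_CKPT_DIR = re.compile(r"checkpoint_model_(\d{6})")


def latest_checkpoint(log_dir: str) -> Optional[Path]:
    """Return <log_dir>/checkpoint_model_NNNNNN/model.pt with the highest step, or None."""
    best_step, best_path = -1, None
    for d in Path(log_dir).iterdir():
        m = _CKPT_DIR.fullmatch(d.name)
        # Skip folders whose model.pt is missing (e.g. a save still in progress).
        if m and (d / "model.pt").is_file() and int(m.group(1)) > best_step:
            best_step, best_path = int(m.group(1)), d / "model.pt"
    return best_path
```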

Stage 2: DMD distillation

Distills the causal student against the bidirectional teacher (Wan2.2 + UCPE). Two variants:

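  • ucpe_causal_forcing_dmd_720_wan21_sf.yaml: camera branch trainable, with lr_cam_multiplier=10 (the default in trainer/distillation.py).
  • ucpe_causal_forcing_dmd_720_wan21_sf_freeze.yaml: sets freeze_camera_branch: true, so the camera branch still participates in the forward pass but receives no gradient.

Run on each of the 4 nodes (set NODE_RANK=0..3). The example uses variant A; point CONFIG at the _freeze.yaml config for variant B: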
LOG_DIR=output/ucpe_training_720_v2/dmd_wan21_sf \
  CONFIG=configs/ucpe_causal_forcing_dmd_720_wan21_sf.yaml \
  NODE_RANK=0 MASTER_PORT=34576 MASTER_ADDR=<node0-ip> \
  bash scripts/run_dmd_720_multinode.sh
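The two variants differ only in how the camera branch reaches the optimizer. A minimal sketch of that split, not the logic in trainer/distillation.py: it assumes camera-branch parameters can be recognised by the substring cam_self_attn in their names (hypothetical; the name comes from the architecture description at the top) and uses name lists in place of tensors.

```python
CAM_MARKER = "cam_self_attn"  # hypothetical: substring identifying camera-branch params


def build_param_groups(param_names: list[str], base_lr: float, *,
                       freeze_camera_branch: bool = False,
                       lr_cam_multiplier: float = 10.0) -> list[dict]:
    """Split generator parameters into optimizer groups by name.

    Variant A: camera params get base_lr * lr_cam_multiplier.
    Variant B: camera params are left out entirely (never updated).
    """
    cam = [n for n in param_names if CAM_MARKER in n]
    rest = [n for n in param_names if CAM_MARKER not in n]
    groups = [{"params": rest, "lr": base_lr}]
    if not freeze_camera_branch and cam:
        groups.append({"params": cam, "lr": base_lr * lr_cam_multiplier})
    return groups
```

With real PyTorch parameters, the frozen variant would also call requires_grad_(False) on the camera parameters so autograd skips them while they still run in the forward pass.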

Each step takes ~17 s on 4×8 H100s; ~1000 steps are recommended.
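For budgeting, the recommended run works out roughly as follows (step time only; startup and the 130+ GB checkpoint writes are not included):

```latex
\begin{aligned}
T_{\text{wall}} &\approx 1000 \times 17\,\text{s} = 17{,}000\,\text{s} \approx 4.7\,\text{h} \\
T_{\text{GPU}}  &\approx 17{,}000\,\text{s} \times (4 \times 8) = 544{,}000\,\text{GPU-s} \approx 151\,\text{GPU-hours}
\end{aligned}
```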


Inference

DMD causal student (few-step, fast)

python scripts/test_ucpe_dmd.py \
  --config_path configs/ucpe_causal_forcing_dmd_720_wan21_sf.yaml \
  --checkpoint_path /path/to/dmd_unfreeze_cam_wan21_sf/checkpoint_model_001000/model.pt \
  --output_folder ./output/test \
  --use_ema \
  --num_samples 8

--use_ema is required for DMD checkpoints (it loads generator_ema). Omit it for ODE-stage checkpoints, which contain only generator.
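If you script inference over several checkpoints, the --use_ema rule can be applied automatically. This hypothetical helper infers the stage from the top-level folder name of the published layout (dmd_* vs ode_regression_*) and builds the argument list for the command shown above; it does not inspect the checkpoint itself, and which config to pair with an ODE-stage checkpoint is left to the caller.

```python
from pathlib import Path


def dmd_inference_cmd(checkpoint_path: str, config_path: str,
                      output_folder: str = "./output/test",
                      num_samples: int = 8) -> list[str]:
    """Build the test_ucpe_dmd.py command, adding --use_ema only for DMD checkpoints."""
    # .../<stage_dir>/checkpoint_model_NNNNNN/model.pt -> <stage_dir>
    stage_dir = Path(checkpoint_path).parent.parent.name
    cmd = ["python", "scripts/test_ucpe_dmd.py",
           "--config_path", config_path,
           "--checkpoint_path", checkpoint_path,
           "--output_folder", output_folder,
           "--num_samples", str(num_samples)]
    if stage_dir.startswith("dmd_"):
        cmd.append("--use_ema")  # DMD bundles: load generator_ema
    elif not stage_dir.startswith("ode_regression_"):
        raise ValueError(f"cannot infer stage from folder {stage_dir!r}")
    return cmd

# subprocess.run(dmd_inference_cmd(path, "configs/...yaml"), check=True)
```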

Wan2.2 bidirectional teacher (50-step, source-of-truth)

The bidirectional checkpoint was trained with UCPE's DiffSynth-based pipeline, so run it through UCPE's scripts/predict_one_sample.py:

cd /path/to/UCPE      # the UCPE repo, NOT cf_ucpe
HF_HUB_OFFLINE=1 python scripts/predict_one_sample.py \
  --video_id <panshot_video_id> \
  --ckpt_path /path/to/wan22_bidirectional_ucpe \
  --output_path ./bidir.mp4 \
  --num_inference_steps 50

Select the test-split sample with either --video_id (recommended) or --sample_idx.
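A sketch that builds the teacher command with exactly one sample selector, assuming --video_id and --sample_idx are alternatives as described above. teacher_cmd is a hypothetical helper; run the result from the UCPE repo root with HF_HUB_OFFLINE=1, as in the shell example.

```python
from typing import Optional


def teacher_cmd(ckpt_path: str, output_path: str, *, video_id: Optional[str] = None,
                sample_idx: Optional[int] = None, steps: int = 50) -> list[str]:
    """Build the UCPE predict_one_sample.py command (run from the UCPE repo root)."""
    if (video_id is None) == (sample_idx is None):
        raise ValueError("pass exactly one of video_id or sample_idx")
    selector = (["--video_id", video_id] if video_id is not None
                else ["--sample_idx", str(sample_idx)])
    return ["python", "scripts/predict_one_sample.py", *selector,
            "--ckpt_path", ckpt_path, "--output_path", output_path,
            "--num_inference_steps", str(steps)]

# subprocess.run(teacher_cmd("/path/to/wan22_bidirectional_ucpe", "./bidir.mp4", sample_idx=0),
#                cwd="/path/to/UCPE", env={**os.environ, "HF_HUB_OFFLINE": "1"}, check=True)
```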


Visualization (4-panel comparison)

Generates GT / camera-trajectory / Wan2.2 bidirectional / DMD as a 2×2 grid mp4 for one PanShot test sample:

# 1. Run all four sources for one sample (writes to output/comparison/<sample_dir>/)
python scripts/compare_inference.py \
  --config_path configs/ucpe_causal_forcing_dmd_720_wan21_sf.yaml \
  --dmd_ckpt /path/to/dmd_unfreeze_cam_wan21_sf/checkpoint_model_001000/model.pt \
  --use_ema \
  --sample_idx 0 \
  --output_root output/comparison

# 2. Compose the 2x2 grid (renders camera trajectory + ffmpeg stack)
python scripts/compare_grid.py --input_dir output/comparison/0000_<video_id>/

Output: output/comparison/0000_<video_id>/grid.mp4.
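For reference, the 2×2 composition can be expressed with ffmpeg's xstack filter. This sketch only builds such a command: the input file names and the panel order (GT and trajectory on top, bidirectional and DMD below) are assumptions, not what compare_grid.py is known to do.

```python
def xstack_cmd(gt: str, traj: str, bidir: str, dmd: str, out: str) -> list[str]:
    """ffmpeg command stacking four clips into a 2x2 grid.

    The layout string assumes all four clips share the same width and height.
    """
    layout = "0_0|w0_0|0_h0|w0_h0"  # top-left, top-right, bottom-left, bottom-right
    return ["ffmpeg", "-y",
            "-i", gt, "-i", traj, "-i", bidir, "-i", dmd,
            "-filter_complex",
            f"[0:v][1:v][2:v][3:v]xstack=inputs=4:layout={layout}[v]",
            "-map", "[v]", out]
```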

For a batch over 8 GPUs (samples 0..31, ~5 min):

START=0 END=31 bash scripts/compare_batch_8gpu.sh
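Reading START=0 END=31 as an inclusive range gives 32 samples, i.e. 4 per GPU. How compare_batch_8gpu.sh actually splits them is not documented here; a contiguous, near-equal split is one reasonable choice, sketched below with a hypothetical shard_samples helper.

```python
def shard_samples(start: int, end: int, num_gpus: int = 8) -> list[list[int]]:
    """Split the inclusive range start..end into num_gpus contiguous, near-equal chunks."""
    ids = list(range(start, end + 1))
    base, extra = divmod(len(ids), num_gpus)
    shards, pos = [], 0
    for g in range(num_gpus):
        size = base + (1 if g < extra else 0)  # first `extra` GPUs take one more
        shards.append(ids[pos:pos + size])
        pos += size
    return shards
```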

The trajectory is rendered as a 3D camera frustum gizmo over the actual world-space camera path. Frustum size auto-scales to the trajectory bbox; tweak with --frustum_scale_ratio (default 1/12) on compare_grid.py without redoing inference.
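One plausible reading of "auto-scales to the trajectory bbox", sketched below: frustum size = ratio × diagonal of the axis-aligned bounding box of the camera positions. Whether compare_grid.py uses the diagonal, the longest side, or another extent is an assumption, and frustum_size is a hypothetical name.

```python
import math


def frustum_size(positions: list[tuple[float, float, float]], ratio: float = 1 / 12) -> float:
    """Frustum size as a fraction of the trajectory's bounding-box diagonal."""
    if not positions:
        raise ValueError("empty trajectory")
    lo = [min(p[i] for p in positions) for i in range(3)]
    hi = [max(p[i] for p in positions) for i in range(3)]
    return ratio * math.dist(lo, hi)
```

A static camera has a zero-size box, so a real implementation would likely clamp the result to some minimum size.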


Citation / contact

Code: https://github.com/weijielyu/RayStream_CF
Author: Weijie Lyu (weijielyu1@gmail.com)
