UCPE Causal-Forcing Checkpoints
Wan2.2-TI2V-5B + UCPE camera-control checkpoints for the causal video-generation pipeline at
github.com/weijielyu/RayStream_CF (cf_ucpe repo).
All checkpoints are at 704×1280 (720p), 81 frames @ 16 fps, TI2V with UCPE
camera conditioning (relray_absmap, attn_compress=8, parallel cam_self_attn
branches at every DiT block).
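As a quick sanity check on these numbers, the sketch below computes the clip duration and the latent grid size. The latent part rests on an assumption that this README does not state: a VAE stride of 4 in time and 16 in space, the ratio published for the Wan2.2 TI2V-5B VAE. The function names are illustrative and are not part of cf_ucpe.

```python
def clip_seconds(num_frames: int, fps: int) -> float:
    """Duration of a clip in seconds."""
    return num_frames / fps


def latent_shape(num_frames: int, height: int, width: int,
                 t_stride: int = 4, s_stride: int = 16):
    """Latent (frames, height, width) under ASSUMED VAE strides.

    Assumption: a causal video VAE that keeps the first frame and
    compresses the remaining frames in groups of `t_stride`.
    """
    latent_frames = (num_frames - 1) // t_stride + 1
    return latent_frames, height // s_stride, width // s_stride


duration = clip_seconds(81, 16)
shape = latent_shape(81, 704, 1280)
print(duration)  # 5.0625
print(shape)     # (21, 44, 80)
```

If the checkpoints use different strides, pass them explicitly through t_stride and s_stride.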
Repository layout
.
├── README.md                                # this file
├── wan22_bidirectional_ucpe/                # Wan2.2 bidirectional teacher (DeepSpeed ckpt, ~24 GB)
│   ├── checkpoint/
│   │   ├── mp_rank_00_model_states.pt       # ← actual weights (21 GB)
│   │   └── bf16_zero_pp_rank_*.pt           # optimizer shards (8 × 213 MB)
│   ├── latest
│   └── zero_to_fp32.py
│
├── ode_regression_wan21_sf/                 # Stage-1: causal student after DF-style ODE regression
│   ├── checkpoint_model_000400/model.pt     # 400 steps (~20 GB)
│   └── checkpoint_model_001000/model.pt     # 1000 steps (~20 GB)
│
├── dmd_unfreeze_cam_wan21_sf/               # Stage-2 variant A: DMD with camera branch trainable (lr_cam=10x)
│   ├── checkpoint_model_000500/model.pt     # 500 steps (~135 GB, full-resume bundle)
│   └── checkpoint_model_001000/model.pt     # 1000 steps (~135 GB)
│
└── dmd_freeze_cam_wan21_sf/                 # Stage-2 variant B: DMD with camera branch frozen
    ├── checkpoint_model_000500/model.pt     # 500 steps (~132 GB)
    └── checkpoint_model_001000/model.pt     # 1000 steps (~141 GB)
cf_ucpe ckpt format:

| dir | top-level keys |
|---|---|
| `ode_regression_wan21_sf/*/model.pt` | `generator` |
| `dmd_*_wan21_sf/*/model.pt` | `generator`, `generator_ema`, `fake_score`, `generator_optimizer`, `critic_optimizer`, `step` |

For inference you only need generator_ema (DMD) or generator (ODE); see
scripts/extract_ema_ckpt.py in the code repo to slim the checkpoints down.
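The idea behind slimming can be sketched as follows. This is a minimal illustration, not the repo's scripts/extract_ema_ckpt.py, which may differ; the function names are hypothetical, and keeping the original top-level key in the output is an assumption made so that a loader reading generator_ema or generator would still find it.

```python
from typing import Any, Dict


def select_inference_weights(ckpt: Dict[str, Any],
                             use_ema: bool = True) -> Dict[str, Any]:
    """Keep only the top-level entry needed for inference."""
    key = "generator_ema" if use_ema else "generator"
    if key not in ckpt:
        raise KeyError(f"{key!r} not in checkpoint; found {sorted(ckpt)}")
    # Assumption: the original key name is kept in the slimmed file.
    return {key: ckpt[key]}


def slim_checkpoint(src: str, dst: str, use_ema: bool = True) -> None:
    """Load a full bundle on CPU and save only the inference weights."""
    import torch  # lazy import: the helper above needs no dependencies

    # weights_only=False because the bundle also holds optimizer state;
    # only do this with checkpoints from a source you trust.
    ckpt = torch.load(src, map_location="cpu", weights_only=False)
    torch.save(select_inference_weights(ckpt, use_ema), dst)


# Toy stand-in for a DMD bundle, mirroring the key list in the table.
fake_bundle = {
    "generator": {"w": 1},
    "generator_ema": {"w": 2},
    "fake_score": {},
    "generator_optimizer": {},
    "critic_optimizer": {},
    "step": 1000,
}
slim = select_inference_weights(fake_bundle)
print(list(slim))  # ['generator_ema']
```

Loading a ~135 GB bundle this way needs enough host RAM to hold it, so run it on a machine sized accordingly.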
The Wan2.2 bidirectional ckpt is in DeepSpeed ZeRO-3 layout. Code that loads
it (e.g. UCPE/scripts/predict_one_sample.py) reads
checkpoint/mp_rank_00_model_states.pt directly.
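For code outside UCPE that needs the same file, a small path resolver might look like the sketch below. It relies on the DeepSpeed convention that the latest file contains the name of the most recent checkpoint sub-directory (here, checkpoint); the function name is hypothetical, and the "module" key mentioned in the comment is the usual DeepSpeed location for model weights, which should be verified against the actual file.

```python
import tempfile
from pathlib import Path


def model_states_path(ckpt_root: str) -> Path:
    """Resolve <root>/<tag>/mp_rank_00_model_states.pt via the `latest` file."""
    root = Path(ckpt_root)
    latest = root / "latest"
    # DeepSpeed writes the tag (sub-directory name) of the last save here.
    # Fallback to "checkpoint" matches the layout shown in this README.
    tag = latest.read_text().strip() if latest.is_file() else "checkpoint"
    path = root / tag / "mp_rank_00_model_states.pt"
    if not path.is_file():
        raise FileNotFoundError(path)
    return path


# Self-check on a throwaway directory that mimics the layout above.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "checkpoint").mkdir()
    (Path(tmp) / "latest").write_text("checkpoint\n")
    (Path(tmp) / "checkpoint" / "mp_rank_00_model_states.pt").touch()
    rel = model_states_path(tmp).relative_to(tmp).as_posix()
    print(rel)  # checkpoint/mp_rank_00_model_states.pt

# With torch installed, the weights would then typically be read as:
#   state = torch.load(path, map_location="cpu")["module"]  # assumed key
```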
Quick start: download
huggingface-cli download wlyu/ucpe_checkpoints --local-dir ./ucpe_checkpoints
Or pull a specific subfolder:
huggingface-cli download wlyu/ucpe_checkpoints \
--include 'dmd_unfreeze_cam_wan21_sf/checkpoint_model_001000/*' \
--local-dir ./ucpe_checkpoints
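The same selective download can be scripted from Python with huggingface_hub.snapshot_download, as in the sketch below. The helper names are illustrative; the pattern builder simply reproduces the directory naming shown in the repository layout (six-digit, zero-padded step).

```python
def checkpoint_pattern(run: str, step: int) -> str:
    """Glob pattern for one checkpoint folder, e.g. step 1000 -> 001000."""
    return f"{run}/checkpoint_model_{step:06d}/*"


def download_checkpoint(run: str, step: int,
                        local_dir: str = "./ucpe_checkpoints") -> str:
    """Download a single checkpoint folder and return the local path."""
    # Lazy import so the pattern helper works without the dependency.
    from huggingface_hub import snapshot_download  # pip install huggingface_hub

    return snapshot_download(
        repo_id="wlyu/ucpe_checkpoints",
        allow_patterns=[checkpoint_pattern(run, step)],
        local_dir=local_dir,
    )


pattern = checkpoint_pattern("dmd_unfreeze_cam_wan21_sf", 1000)
print(pattern)  # dmd_unfreeze_cam_wan21_sf/checkpoint_model_001000/*
```

Calling download_checkpoint("dmd_unfreeze_cam_wan21_sf", 1000) would fetch the same files as the CLI command above; note that each DMD bundle is well over 100 GB.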
Training (in cf_ucpe)
The full pipeline is Wan2.2 bidirectional → ODE regression (causal student) → DMD distillation.
Stage 1: ODE regression (DF mode, matches upstream Self-Forcing)
A single causal forward pass with per-block random timesteps and no clean_x.
Configured via use_df: true, which dispatches to
model.ode_regression_df.ODERegressionDF.
Run on each of 4 nodes (set NODE_RANK=0..3):
LOG_DIR=output/ucpe_training_720_v2/ode_regression_wan21_sf \
CONFIG=configs/ucpe_ode_regression_720_wan21_sf.yaml \
NODE_RANK=0 MASTER_PORT=36903 MASTER_ADDR=<node0-ip> \
bash scripts/run_ode_regression_720_multinode.sh
Checkpoints are saved every 200 steps; ~1000 steps in total is enough.
Stage 2: DMD distillation
Distills the causal student against the bidirectional teacher (Wan2.2 + UCPE). Two variants:
- ucpe_causal_forcing_dmd_720_wan21_sf.yaml: camera branch trainable, with lr_cam_multiplier=10 (default in trainer/distillation.py).
- ucpe_causal_forcing_dmd_720_wan21_sf_freeze.yaml: freeze_camera_branch: true; the camera branch participates in the forward pass but receives no gradient.
LOG_DIR=output/ucpe_training_720_v2/dmd_wan21_sf \
CONFIG=configs/ucpe_causal_forcing_dmd_720_wan21_sf.yaml \
NODE_RANK=0 MASTER_PORT=34576 MASTER_ADDR=<node0-ip> \
bash scripts/run_dmd_720_multinode.sh
Each step takes ~17 s on 4×8 H100 GPUs; ~1000 steps are recommended.
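For planning purposes, the quoted step time translates into wall-clock time and GPU-hours as sketched below. The estimate assumes 32 GPUs in total (4×8), a constant 17 s per step, and no overhead for checkpointing or restarts, so treat the result as a lower bound.

```python
def training_budget(steps: int, sec_per_step: float, num_gpus: int):
    """Return (wall-clock hours, GPU-hours) for a fixed step time."""
    wall_hours = steps * sec_per_step / 3600
    return wall_hours, wall_hours * num_gpus


wall, gpu_hours = training_budget(steps=1000, sec_per_step=17.0, num_gpus=4 * 8)
print(f"{wall:.1f} h wall-clock, {gpu_hours:.0f} GPU-hours")
# 4.7 h wall-clock, 151 GPU-hours
```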
Inference
DMD causal student (few-step, fast)
python scripts/test_ucpe_dmd.py \
--config_path configs/ucpe_causal_forcing_dmd_720_wan21_sf.yaml \
--checkpoint_path /path/to/dmd_unfreeze_cam_wan21_sf/checkpoint_model_001000/model.pt \
--output_folder ./output/test \
--use_ema \
--num_samples 8
--use_ema is required for DMD checkpoints (it loads generator_ema). Omit it
for ODE-stage checkpoints, which contain only generator.
Wan2.2 bidirectional teacher (50-step, source-of-truth)
The bidirectional ckpt was trained against UCPE's diffsynth-based pipeline.
Run via UCPE's scripts/predict_one_sample.py:
cd /path/to/UCPE # the UCPE repo, NOT cf_ucpe
HF_HUB_OFFLINE=1 python scripts/predict_one_sample.py \
--video_id <panshot_video_id> \
--ckpt_path /path/to/wan22_bidirectional_ucpe \
--output_path ./bidir.mp4 \
--num_inference_steps 50
Select the sample with --video_id (recommended) or with --sample_idx, which indexes the test split.
Visualization (4-panel comparison)
Generates GT / camera-trajectory / Wan2.2 bidirectional / DMD as a 2×2 grid mp4 for one PanShot test sample:
# 1. Run all four sources for one sample (writes to output/comparison/<sample_dir>/)
python scripts/compare_inference.py \
--config_path configs/ucpe_causal_forcing_dmd_720_wan21_sf.yaml \
--dmd_ckpt /path/to/dmd_unfreeze_cam_wan21_sf/checkpoint_model_001000/model.pt \
--use_ema \
--sample_idx 0 \
--output_root output/comparison
# 2. Compose the 2x2 grid (renders camera trajectory + ffmpeg stack)
python scripts/compare_grid.py --input_dir output/comparison/0000_<video_id>/
Output: output/comparison/0000_<video_id>/grid.mp4.
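To illustrate the kind of ffmpeg stacking involved, the sketch below builds a 2×2 command with ffmpeg's xstack filter. It is not the code in compare_grid.py: the input file names are hypothetical, and the layout assumes all four videos share the same resolution and frame count.

```python
import shlex
from typing import List


def xstack_2x2_cmd(inputs: List[str], output: str) -> List[str]:
    """Build an ffmpeg command that tiles four videos in row-major order."""
    if len(inputs) != 4:
        raise ValueError("a 2x2 grid needs exactly four inputs")
    cmd = ["ffmpeg", "-y"]
    for path in inputs:
        cmd += ["-i", path]
    # Positions: top-left, top-right, bottom-left, bottom-right,
    # expressed relative to the width/height of input 0.
    graph = ("[0:v][1:v][2:v][3:v]"
             "xstack=inputs=4:layout=0_0|w0_0|0_h0|w0_h0[v]")
    cmd += ["-filter_complex", graph, "-map", "[v]", output]
    return cmd


# Hypothetical per-panel files; the real names in the sample dir may differ.
cmd = xstack_2x2_cmd(
    ["gt.mp4", "trajectory.mp4", "bidir.mp4", "dmd.mp4"], "grid.mp4"
)
print(shlex.join(cmd))
```

The returned list can be passed to subprocess.run(cmd, check=True) once the four panel videos exist.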
For a batch over 8 GPUs (samples 0..31, ~5 min):
START=0 END=31 bash scripts/compare_batch_8gpu.sh
The trajectory is rendered as a 3D camera frustum gizmo over the actual world-space
camera path. Frustum size auto-scales to the trajectory bbox; tweak with
--frustum_scale_ratio (default 1/12) on compare_grid.py without redoing inference.
Citation / contact
Code: https://github.com/weijielyu/RayStream_CF
Author: Weijie Lyu (weijielyu1@gmail.com)