Add NVFP4 Conv3d export for diffusers VAE (Wan 2.2) #1809
No actionable comments were generated in the recent review. 🎉
ℹ️ Recent review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID:
📒 Files selected for processing (7)
✅ Files skipped from review due to trivial changes (2)
🚧 Files skipped from review as they are similar to previous changes (5)
📝 Walkthrough
Adds NVFP4 Conv3d weight export support to the unified Hugging Face diffusers export pipeline. A new
Changes
NVFP4 Conv3d export pipeline
Sequence Diagram(s)
sequenceDiagram
participant CLI as Diffusers Export CLI
participant Exporter as _export_diffusers_checkpoint
participant Process as _process_quantized_modules
participant ConvPack as _export_quantized_conv_weight
participant Post as _postprocess_safetensors
participant Pad as pad_nvfp4_weights
participant Swizzle as swizzle_nvfp4_scales
CLI->>Exporter: --hf-ckpt-dir provided
Exporter->>Exporter: compute conv_nvfp4_prefixes via is_quantconv3d
Exporter->>Process: iterate named modules
Process->>Process: is_quantconv3d(sub_module)?
Process->>ConvPack: fsdp2_aware_weight_update → flatten/pad/quantize
ConvPack-->>Process: weight(uint8), weight_scale(fp8), weight_scale_2(fp32)
Exporter->>Post: save safetensors, nvfp4_exclude_layers=conv_nvfp4_prefixes
Post->>Pad: state_dict, exclude_layers=conv_nvfp4_prefixes
Post->>Swizzle: state_dict, exclude_layers=conv_nvfp4_prefixes
Post-->>Exporter: safetensors written (conv tensors unchanged by pad/swizzle)
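The exclusion step in the diagram can be pictured as a prefix filter applied before pad/swizzle. The sketch below assumes a matching rule of "exact prefix, or prefix followed by a dot"; `is_excluded` and `postprocess` are hypothetical helpers for illustration, and the real `pad_nvfp4_weights` / `swizzle_nvfp4_scales` signatures are not reproduced here.

```python
def is_excluded(key, exclude_layers):
    """True when `key` belongs to one of the excluded module prefixes.

    Assumption: a prefix matches only on a module boundary (the next
    character is a dot), so "vae.conv1" does not capture "vae.conv10".
    """
    return any(key == p or key.startswith(p + ".") for p in exclude_layers)


def postprocess(state_dict, transform, exclude_layers=None):
    """Apply `transform` to every entry except those under excluded prefixes."""
    exclude_layers = exclude_layers or set()
    return {
        k: v if is_excluded(k, exclude_layers) else transform(v)
        for k, v in state_dict.items()
    }


# Strings stand in for tensors; str.upper stands in for pad/swizzle.
state = {
    "vae.conv1.weight": "logical",
    "vae.conv10.weight": "logical",
    "transformer.proj.weight": "logical",
}
out = postprocess(state, str.upper, exclude_layers={"vae.conv1"})
```

Here only `vae.conv1.weight` keeps its original value; the other two entries pass through the transform.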
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 5 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
/claude review
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/unit/torch/export/test_nvfp4_conv_export_diffusers.py (1)
108-123: 🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
Scope on-disk schema checks to the Conv3d modules discovered earlier.
`conv_prefixes` is currently built from every `*.weight_scale_2` key, so this test can pass even if Conv3d export regresses but other NVFP4 layers remain valid.
Suggested patch
-    conv_prefixes = [
-        k[: -len(".weight_scale_2")] for k in state_dict if k.endswith(".weight_scale_2")
-    ]
-    assert conv_prefixes, "no NVFP4 conv layers found on disk"
-    for prefix in conv_prefixes:
+    exported_prefixes = {
+        k[: -len(".weight_scale_2")] for k in state_dict if k.endswith(".weight_scale_2")
+    }
+    expected_conv_prefixes = set(quant_convs)
+    missing = expected_conv_prefixes - exported_prefixes
+    assert not missing, f"missing exported NVFP4 conv layers on disk: {sorted(missing)}"
+    for prefix in sorted(expected_conv_prefixes):
         weight = state_dict[f"{prefix}.weight"]
         scale = state_dict[f"{prefix}.weight_scale"]
         scale_2 = state_dict[f"{prefix}.weight_scale_2"]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/torch/export/test_nvfp4_conv_export_diffusers.py` around lines 108 - 123, The conv_prefixes list is currently built from all state_dict keys ending with .weight_scale_2, which means the subsequent assertion checks can pass even if Conv3d export regresses. Instead, build conv_prefixes from the list of Conv3d modules that were discovered earlier in the test (before this section), so that the assertions in this block only validate the Conv3d modules that were actually expected to be exported. This ensures the test properly scopes its validation to the specific Conv3d modules rather than all NVFP4 layers present in the state_dict.
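The scoping this comment asks for amounts to a set difference between the Conv3d prefixes the test expects and the prefixes found on disk. A minimal sketch, with a plain dict standing in for the loaded safetensors state dict; `missing_conv_prefixes` and the key names are illustrative and not taken from the test file.

```python
def missing_conv_prefixes(state_dict, expected_conv_prefixes):
    """Return the expected Conv3d prefixes that have no NVFP4 tensors on disk."""
    suffix = ".weight_scale_2"
    # Any layer exported as NVFP4 (Linear or Conv3d) carries this key.
    exported = {k[: -len(suffix)] for k in state_dict if k.endswith(suffix)}
    return sorted(set(expected_conv_prefixes) - exported)


# Hypothetical on-disk keys: one Linear and one Conv3d exported as NVFP4.
state_dict = {
    "transformer.proj.weight": None,
    "transformer.proj.weight_scale": None,
    "transformer.proj.weight_scale_2": None,
    "vae.conv1.weight": None,
    "vae.conv1.weight_scale": None,
    "vae.conv1.weight_scale_2": None,
}

ok = missing_conv_prefixes(state_dict, ["vae.conv1"])
regressed = missing_conv_prefixes(state_dict, ["vae.conv1", "vae.conv2"])
```

With the unscoped check, the `transformer.proj` keys alone keep the prefix list non-empty, so a missing `vae.conv2` would go unnoticed; the set difference reports it by name.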
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/unit/torch/export/test_nvfp4_conv_export.py`:
- Line 318: Move the `from safetensors.torch import load_file, save_file` import
statement from line 318 inside the test function to the module scope at the top
of the file with the other imports. This ensures import failures are caught at
module load time rather than at runtime, and follows the project's import
convention that requires imports at the top of the file unless there is a
specific reason (circular imports or optional dependencies) to place them inside
functions.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 8ecadbbc-27bb-42c7-ba9f-6e8897cfd119
📒 Files selected for processing (8)
CHANGELOG.rst
examples/diffusers/README.md
examples/diffusers/quantization/quantize.py
modelopt/torch/export/diffusers_utils.py
modelopt/torch/export/layer_utils.py
modelopt/torch/export/unified_export_hf.py
tests/unit/torch/export/test_nvfp4_conv_export.py
tests/unit/torch/export/test_nvfp4_conv_export_diffusers.py
def test_postprocess_safetensors_excludes_conv(tmp_path):
    """Conv stays logical on disk when pad/swizzle are enabled for other layers."""
    from safetensors.torch import load_file, save_file
📐 Maintainability & Code Quality | 🟠 Major | ⚡ Quick win
Move safetensors imports to module scope.
Line 318 imports inside a test function without justification; this delays import failures until runtime and violates the test import convention.
As per path instructions, “Imports belong at the top of the file ... The only acceptable in-function imports are for circular imports or optional dependencies ... with a brief comment naming the reason.”
Suggested patch
 import pytest
 import torch
 import torch.nn as nn
+from safetensors.torch import load_file, save_file
 import modelopt.torch.quantization as mtq
 from modelopt.torch.export.diffusers_utils import (
@@
 def test_postprocess_safetensors_excludes_conv(tmp_path):
     """Conv stays logical on disk when pad/swizzle are enabled for other layers."""
-    from safetensors.torch import load_file, save_file
     sd = {**_mk_nvfp4_layer("transformer.proj", 64, 256), **_mk_nvfp4_layer("vae.conv1", 120, 240)}
Source: Path instructions
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1809 +/- ##
===========================================
+ Coverage 62.89% 75.86% +12.97%
===========================================
Files 511 511
Lines 56632 58620 +1988
===========================================
+ Hits 35616 44472 +8856
+ Misses 21016 14148 -6868
cjluo-nv left a comment
Bot review — DM the bot to share feedback.
Adds NVFP4 Conv3d weight export for the diffusers Wan 2.2 VAE to the unified HF export path. +636/-18, 8 files. I traced the core logic and it holds up: the conv flatten weight.reshape(O, -1) matches the calibration-time flatten in _nvfp4_quantize_weight_along_k (so byte-exactness is plausible and is pinned by a real test); the pad-before-scale ordering is correct (get_weights_scaling_factor requires block-divisible last dim); the static-quantizer rejection via NVFP4QTensor._is_static_quantizer is sound and tested; is_quantconv3d correctly matches _QuantDiffusersWanCausalConv3d/QuantConv3d while excluding Conv1d/Conv2d (no "Conv3d" substring) and ConvTranspose; and the nvfp4_exclude_layers plumbing correctly keeps conv prefixes out of the Linear-targeted pad/swizzle. New optional kwargs are backward-compatible. The lazy ONNX import in quantize.py has a valid documented justification (heavy optional dependency). Test coverage is strong (byte-exact vs NVFP4QTensor.quantize, dequant round-trip, schema, K-order sensitivity, dispatch routing, quantizer-hiding, pad/swizzle exclusion, plus a tiny real Wan VAE e2e). New test files' license headers match the canonical LICENSE_HEADER (standard-header exception applies — no licensing concern). No prompt-injection in PR text.
Why nudge rather than approve:
- `hide_quantizers_from_state_dict` was generalized from "strip QuantLinear's weight/input/output quantizers" to "strip every `*_quantizer` child of every module." This now also strips attention bmm/softmax quantizers across ALL diffusers exports, not just conv. My read is this is a cleanup (the unified HF safetensors path never converted attention quantizers into usable scale buffers; FP8 MHA scales go through the ONNX `export_fp8_mha` symbolic path, so the previously-serialized `*_quantizer._amax` keys were unusable junk), but it changes a shared code path with broad blast radius (Flux/SD3/Wan/LTX-2) and only has manual GPU e2e validation.
- The real numeric conv export was only manually validated on a Wan 2.2 5B VAE (48 layers); CI coverage is CPU-only.
- Minor: some scale-derivation/input_scale duplication between `_export_quantized_conv_weight` and `_export_quantized_weight`, acceptable given the conv-specific flatten/static-reject.
Recommend a diffusers-export owner confirm the hide_quantizers_from_state_dict broadening is safe for the FP8-MHA/Linear paths and that the GPU conv export numerics are signed off.
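To make the "strip every `*_quantizer` child, restore on exit" behavior concrete, here is a toy sketch of the pattern. It is not ModelOpt's implementation: `Node` is a hypothetical stand-in for `torch.nn.Module`, and `hide_quantizer_children` is an illustrative name. It only shows why attention quantizers disappear from the serialized keys once the rule is name-based rather than QuantLinear-specific.

```python
from contextlib import contextmanager


class Node:
    """Toy stand-in for a module: named children and a flat key listing."""

    def __init__(self, **children):
        self.children = dict(children)

    def keys(self, prefix=""):
        out = []
        for name, child in self.children.items():
            out.append(prefix + name)
            out.extend(child.keys(prefix + name + "."))
        return out


@contextmanager
def hide_quantizer_children(root):
    """Temporarily detach every child whose name ends in `_quantizer`."""
    saved = []  # (node, original children dict) pairs, restored on exit

    def strip(node):
        saved.append((node, node.children))
        node.children = {
            name: child
            for name, child in node.children.items()
            if not name.endswith("_quantizer")
        }
        for child in node.children.values():
            strip(child)

    strip(root)
    try:
        yield root
    finally:
        for node, original in saved:
            node.children = original


model = Node(
    conv=Node(weight=Node(), weight_quantizer=Node(_amax=Node())),
    attn=Node(q_bmm_quantizer=Node(_amax=Node()), softmax_quantizer=Node()),
)
before = model.keys()
with hide_quantizer_children(model):
    hidden = model.keys()
after = model.keys()
```

Inside the `with` block the key listing contains no quantizer entries for either `conv` or `attn`; afterwards the original tree is back, which is the property an export owner would want confirmed for the FP8-MHA and Linear paths.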
Route quantized Conv3d modules through the unified HuggingFace diffusers export so their weights serialize in the logical flattened-K NVFP4 schema, matching NVFP4 Linear. Previously such modules fell through `_process_quantized_modules` unpacked and leaked quantizer buffers.
- layer_utils: add `is_quantconv3d` (Conv3d-scoped; excludes ConvTranspose / Conv2d / Conv1d).
- unified_export_hf: add `_export_quantized_conv_weight` + a dispatch branch. Flatten `[O,C,kt,kh,kw] -> [O,K_flat]`, pad K to a multiple of 16, derive scales from the flattened weight (dynamic NVFP4), and pack to `weight [O,K_pad/2]` uint8 + `weight_scale [O,K_pad/16]` fp8 + scalar `weight_scale_2` (+ `input_scale` when the activation amax is calibrated). Static NVFP4 conv is rejected.
- diffusers_utils: strip every `*_quantizer` child in `hide_quantizers_from_state_dict` so no quantizer state leaks; exclude Conv3d layers from the opt-in `pad_nvfp4_weights` / `swizzle_nvfp4_scales` post-processing (conv stays logical; the conv kernel does its own layout prep).
- examples/diffusers: document the Wan 2.2 VAE NVFP4 `--hf-ckpt-dir` command.
- tests: CPU coverage (byte-exact vs `NVFP4QTensor.quantize`, dequant round-trip, schema, no-leak, ConvTranspose/static exclusion, pad/swizzle conv-exclusion) + a tiny Wan VAE save/reload test.
Verified end to end: a Wan 2.2 5B VAE calib-32 export produces 48 NVFP4 conv layers, all schema-valid, with zero quantizer keys on disk.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
e0ebadb to ebb674f
Claude review passed — no blocking issues found. LGTM
Scope reviewed: All 8 changed files (modelopt/torch/export/{diffusers_utils,layer_utils,unified_export_hf}.py, examples/diffusers/quantization/quantize.py, both new test files, README, CHANGELOG). Note: the two-dot git diff origin/main HEAD was polluted by unrelated MoE-refactor commits already on main; I reviewed against the authoritative gh pr diff (8 files).
What I verified:
- Algorithm correctness: The export-time flatten `weight.reshape(O, -1)` exactly matches the calibration-time flatten in `_nvfp4_quantize_weight_along_k` (quant_conv.py:81), so per-block grouping is identical between calibration and export. Grouped-conv reduction dim `K = (in_channels // groups) * kt*kh*kw` is correctly captured. Padding K with zeros leaves amax unchanged, so `weight_scale_2` and per-block scales are stable and the byte-exact round-trip holds.
- Scope guards: static NVFP4 weight quantizer is rejected loudly (`NotImplementedError`); non-NVFP4 conv warns and leaves the weight unpacked. Both are strict improvements over the prior fall-through that leaked quantizer buffers.
- Mode/State: the generalized `hide_quantizers_from_state_dict` strips every `*_quantizer` child and restores on exit; the delete-during-`named_modules()` walk is safe (the child `_modules` mutation precedes that level's `.items()` iterator) and matches the prior pattern.
- Export compatibility: Conv3d prefixes are correctly excluded from `pad_nvfp4_weights` / `swizzle_nvfp4_scales` via `nvfp4_exclude_layers`, keeping the logical flattened-K layout for the downstream conv kernel. `input_scale` emission mirrors the NVFP4 Linear path verbatim.
- Routing: `is_quantconv3d` matches `_QuantConv3d` / `QuantConv3d` / `_QuantDiffusersWanCausalConv3d`, and excludes Conv1d/Conv2d and all ConvTranspose variants; consistent string-matching style with `is_quantlinear`.
- Lazy ONNX import in `quantize.py` is confined to `export_onnx` (its only use site).
Additive, backward-compatible, and covered by focused CPU unit + e2e tests. No correctness, mode/state, export, compatibility, or performance concerns found.
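The claim that zero-padding K leaves amax unchanged can be checked on a toy filter without any tensor library. The sketch below uses nested Python lists in place of a `[C, kt, kh, kw]` tensor; the helper names are illustrative, and the block size of 16 is the one stated in the PR description.

```python
BLOCK = 16  # NVFP4 block size along the flattened K dimension


def flatten_row_major(filt):
    """Flatten one nested [C][kt][kh][kw] filter in row-major (contiguous) order."""
    return [v for c in filt for t in c for h in t for v in h]


def pad_to_block(row, block=BLOCK):
    """Right-pad a flattened row with zeros up to a multiple of `block`."""
    return row + [0.0] * ((-len(row)) % block)


def block_amax(row, block=BLOCK):
    """Per-block absolute maximum of an already padded row."""
    return [max(abs(v) for v in row[i : i + block]) for i in range(0, len(row), block)]


# Toy filter with C=2, kt=1, kh=3, kw=3, so K_flat = 18; values run from -8 to 9.
filt = [
    [[[float(c * 9 + h * 3 + w - 8) for w in range(3)] for h in range(3)] for _ in range(1)]
    for c in range(2)
]
row = flatten_row_major(filt)
padded = pad_to_block(row)
amax_before = max(abs(v) for v in row)
amax_after = max(abs(v) for v in padded)
blocks = block_amax(padded)
```

The padded row has 32 entries, the global amax is identical before and after padding, and the last block's amax comes only from its two real values, since the appended zeros cannot exceed any existing magnitude.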
    export_dir: Path,
    pipe: Any | None = None,
    hf_quant_config: dict | None = None,
    nvfp4_exclude_layers: set[str] | None = None,
what are the layers that we would like to exclude? can we infer by naming instead of introducing a new arg?
What does this PR do?
Type of change: new feature
Adds NVFP4 Conv3d weight export to the unified Hugging Face diffusers export. Quantized `Conv3d` layers (concretely the Wan 2.2 VAE `WanCausalConv3d` stack) are serialized in the same logical flattened-K NVFP4 schema already used for NVFP4 `Linear`. Each filter `[O, C, kt, kh, kw]` is flattened to `[O, K_flat]` (PyTorch-contiguous), `K_flat` is padded to a multiple of the block size 16, and the result is stored as packed `weight` (`[O, K_pad/2]` uint8), per-block `weight_scale` (`[O, K_pad/16]` FP8 E4M3) and a scalar `weight_scale_2`; `input_scale` is emitted when the activation amax is calibrated.
Previously these modules fell through `_process_quantized_modules` unpacked and could leak quantizer buffers. This PR:
- adds an `is_quantconv3d` predicate (Conv3d-scoped; excludes ConvTranspose / Conv2d / Conv1d) and a conv branch in `_process_quantized_modules`;
- generalizes `hide_quantizers_from_state_dict` to strip every `*_quantizer` child so no quantizer state (`_amax`) is serialized;
- excludes Conv3d layers from the opt-in `pad_nvfp4_weights` / `swizzle_nvfp4_scales` post-processing so conv stays in logical layout (kernel-side layout preparation is a downstream-runtime concern).
Scope: dynamic NVFP4 Conv3d. Out of scope (downstream or future work): Conv2d / ConvTranspose / Conv1d packing, static/MSE NVFP4 conv, and kernel-side 128x4 SF swizzle / channel alignment / KTRSC repack / runtime alpha.
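As a quick check of the schema described above, the on-disk shapes follow from the filter dimensions alone. This is a sketch of the shape arithmetic only, not of the export code; `nvfp4_conv3d_shapes` is a hypothetical helper, and the `groups` handling follows the `K = (in_channels // groups) * kt*kh*kw` formula quoted in the review.

```python
import math

BLOCK = 16  # NVFP4 block size along the flattened K dimension


def nvfp4_conv3d_shapes(out_channels, in_channels, kt, kh, kw, groups=1):
    """Logical checkpoint shapes for one Conv3d weight [O, C/groups, kt, kh, kw]."""
    k_flat = (in_channels // groups) * kt * kh * kw
    k_pad = math.ceil(k_flat / BLOCK) * BLOCK  # pad K up to a multiple of 16
    return {
        "k_flat": k_flat,
        "k_pad": k_pad,
        "weight": (out_channels, k_pad // 2),  # uint8, two 4-bit codes per byte
        "weight_scale": (out_channels, k_pad // BLOCK),  # one FP8 E4M3 scale per block
        "weight_scale_2": (),  # scalar
    }


# Example: 96 output channels, 3 input channels, 3x3x3 kernel.
shapes = nvfp4_conv3d_shapes(out_channels=96, in_channels=3, kt=3, kh=3, kw=3)
```

For this example `K_flat` is 81, which pads to 96, giving a `[96, 48]` packed weight and a `[96, 6]` per-block scale.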
Usage
python quantize.py \
    --model wan2.2-t2v-5b --backbone vae \
    --format fp4 --quant-algo max --collect-method default \
    --model-dtype BFloat16 --trt-high-precision-dtype BFloat16 \
    --batch-size 1 --calib-size 32 --n-steps 30 \
    --hf-ckpt-dir ./wan22_vae_nvfp4_hf
Testing
`tests/unit/torch/export/`: byte-exact packing vs `NVFP4QTensor.quantize`, dequant round-trip, exported-tensor schema (dtypes/shapes/scalar `weight_scale_2`), no quantizer-state leakage, ConvTranspose and static-NVFP4 exclusion, and a pad/swizzle conv-exclusion regression; plus a tiny `AutoencoderKLWan` save/reload test.
Before your PR is "Ready for review"
`CONTRIBUTING.md`: N/A (no new dependencies; reuses existing NVFP4 helpers).
`/claude review` on the PR.
Additional Information
Kernel-side conv layout (128x4 SF swizzle, channel alignment, KTRSC repack) and the runtime `alpha` are intentionally left to the downstream runtime (e.g. TRT-LLM); ModelOpt stores only the logical checkpoint. The produced Wan 2.2 5B VAE checkpoint itself is kept out of this PR.
Summary by CodeRabbit
Release Notes
New Features
Documentation
`--hf-ckpt-dir` checkpoint generation.
Tests