diff --git a/diffsynth/utils/quantisation/__init__.py b/diffsynth/utils/quantisation/__init__.py
new file mode 100644
index 00000000..ee57f795
--- /dev/null
+++ b/diffsynth/utils/quantisation/__init__.py
@@ -0,0 +1,200 @@
+import logging
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def _quanto_type_map(model_precision: str):
+    if model_precision is None or model_precision == "no_change":
+        return None
+    from optimum.quanto import qfloat8, qfloat8_e4m3fnuz, qint2, qint4, qint8
+
+    mp = model_precision.lower()
+    if mp == "int2-quanto":
+        return qint2
+    elif mp == "int4-quanto":
+        return qint4
+    elif mp == "int8-quanto":
+        return qint8
+    elif mp in ("fp8-quanto", "fp8uz-quanto"):
+        if torch.backends.mps.is_available():
+            logger.warning(
+                "MPS doesn't support dtype float8, please use bf16/fp16/int8-quanto instead."
+            )
+            return None
+        return qfloat8 if mp == "fp8-quanto" else qfloat8_e4m3fnuz
+    else:
+        raise ValueError(f"Invalid quantisation level: {model_precision}")
+
+
+def _quanto_model(
+    model,
+    model_precision,
+    base_model_precision=None,
+    quantize_activations: bool = False,
+):
+    try:
+        from optimum.quanto import quantize, freeze  # noqa
+        # The import alone is enough to apply the monkeypatches in quanto_workarounds.
+        from diffsynth.utils.quantisation import quanto_workarounds  # noqa: F401
+    except ImportError as e:
+        raise ImportError(
+            "To use Quanto, please install the optimum library: `pip install \"optimum[quanto]\"`"
+        ) from e
+
+    if model is None:
+        return model
+    if model_precision is None:
+        model_precision = base_model_precision
+    if model_precision in (None, "no_change"):
+        logger.info("...No quantisation applied to %s.", model.__class__.__name__)
+        return model
+
+    logger.info("Quantising %s. Using %s.", model.__class__.__name__, model_precision)
+    weight_quant = _quanto_type_map(model_precision)
+    if weight_quant is None:
+        logger.info("Quantisation level %s resolved to None, skipping.", model_precision)
+        return model
+
+    extra_quanto_args = {}
+    if quantize_activations:
+        logger.info("Quanto: Freezing model weights and activations")
+        extra_quanto_args["activations"] = weight_quant
+    else:
+        logger.info("Quanto: Freezing model weights only")
+
+    quantize(model, weights=weight_quant, **extra_quanto_args)
+    freeze(model)
+    return model
+
+
+def get_quant_fn(base_model_precision):
+    if base_model_precision is None:
+        return None
+    precision = base_model_precision.lower()
+    if precision == "no_change":
+        return None
+    if "quanto" in precision:
+        return _quanto_model
+    # torchao is not supported here yet.
+    return None
+
+
+def quantise_model(
+    model=None,
+    text_encoders: list = None,
+    controlnet=None,
+    ema=None,
+    args=None,
+    return_dict: bool = False,
+):
+    # Unpack text_encoders; up to 4 are supported for compatibility with the SimpleTuner interface.
+    te1 = te2 = te3 = te4 = None
+    if text_encoders is not None:
+        if len(text_encoders) > 0:
+            te1 = text_encoders[0]
+        if len(text_encoders) > 1:
+            te2 = text_encoders[1]
+        if len(text_encoders) > 2:
+            te3 = text_encoders[2]
+        if len(text_encoders) > 3:
+            te4 = text_encoders[3]
+
+    models = [
+        (
+            model,
+            {
+                "quant_fn": get_quant_fn(args.base_model_precision),
+                "model_precision": args.base_model_precision,
+                "quantize_activations": getattr(args, "quantize_activations", False),
+            },
+        ),
+        (
+            controlnet,
+            {
+                "quant_fn": get_quant_fn(args.base_model_precision),
+                "model_precision": args.base_model_precision,
+                "quantize_activations": getattr(args, "quantize_activations", False),
+            },
+        ),
+        (
+            te1,
+            {
+                "quant_fn": get_quant_fn(args.text_encoder_1_precision),
+                "model_precision": args.text_encoder_1_precision,
+                "base_model_precision": args.base_model_precision,
+            },
+        ),
+        (
+            te2,
+            {
+                "quant_fn": get_quant_fn(args.text_encoder_2_precision),
+                "model_precision": args.text_encoder_2_precision,
+                "base_model_precision": args.base_model_precision,
+            },
+        ),
+        (
+            te3,
+            {
+                "quant_fn": get_quant_fn(args.text_encoder_3_precision),
+                "model_precision": args.text_encoder_3_precision,
+                "base_model_precision": args.base_model_precision,
+            },
+        ),
+        (
+            te4,
+            {
+                "quant_fn": get_quant_fn(args.text_encoder_4_precision),
+                "model_precision": args.text_encoder_4_precision,
+                "base_model_precision": args.base_model_precision,
+            },
+        ),
+        (
+            ema,
+            {
+                "quant_fn": get_quant_fn(args.base_model_precision),
+                "model_precision": args.base_model_precision,
+                "quantize_activations": getattr(args, "quantize_activations", False),
+            },
+        ),
+    ]
+
+    for i, (m, qargs) in enumerate(models):
+        quant_fn = qargs["quant_fn"]
+        if m is None or quant_fn is None:
+            continue
+        quant_args_combined = {
+            "model_precision": qargs["model_precision"],
+            "base_model_precision": qargs.get("base_model_precision", args.base_model_precision),
+            "quantize_activations": qargs.get(
+                "quantize_activations", getattr(args, "quantize_activations", False)
+            ),
+        }
+        logger.info("Quantising %s with %s", m.__class__.__name__, quant_args_combined)
+        models[i] = (quant_fn(m, **quant_args_combined), qargs)
+
+    # Unpack the (possibly quantised) models.
+    model, controlnet, te1, te2, te3, te4, ema = [m for (m, _) in models]
+
+    # Re-pack text_encoders.
+    new_text_encoders = []
+    if te1 is not None:
+        new_text_encoders.append(te1)
+    if te2 is not None:
+        new_text_encoders.append(te2)
+    if te3 is not None:
+        new_text_encoders.append(te3)
+    if te4 is not None:
+        new_text_encoders.append(te4)
+    if len(new_text_encoders) == 0:
+        new_text_encoders = None
+
+    if return_dict:
+        return {
+            "model": model,
+            "text_encoders": new_text_encoders,
+            "controlnet": controlnet,
+            "ema": ema,
+        }
+
+    return model, new_text_encoders, controlnet, ema
\ No newline at end of file
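# Usage sketch (illustrative, not part of the patch): how quantise_model() from the file
# above is expected to be driven. The toy torch modules stand in for pipe.dit and
# pipe.text_encoder; only the args fields that quantise_model() actually reads are set.
import torch
from types import SimpleNamespace

from diffsynth.utils.quantisation import quantise_model

dit = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
text_encoder = torch.nn.Linear(64, 64)

args = SimpleNamespace(
    base_model_precision="int8-quanto",    # quantise the main model's weights to int8
    text_encoder_1_precision="no_change",  # leave the text encoder untouched
    text_encoder_2_precision="no_change",
    text_encoder_3_precision="no_change",
    text_encoder_4_precision="no_change",
    quantize_activations=False,            # weights only
)

dit, text_encoders, _, _ = quantise_model(
    model=dit,
    text_encoders=[text_encoder],
    controlnet=None,
    ema=None,
    args=args,
)
# dit now carries frozen int8 quanto weights; the text encoder is returned unchanged.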
diff --git a/diffsynth/utils/quantisation/peft_workarounds.py b/diffsynth/utils/quantisation/peft_workarounds.py
new file mode 100644
index 00000000..b4868d74
--- /dev/null
+++ b/diffsynth/utils/quantisation/peft_workarounds.py
@@ -0,0 +1,406 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import math
+import warnings
+from typing import Any, Optional
+
+import torch
+
+try:
+    from peft.import_utils import is_quanto_available  # type: ignore
+except Exception:
+    import importlib.util
+
+    def is_quanto_available() -> bool:
+        return importlib.util.find_spec("optimum.quanto") is not None
+
+from peft.tuners.lora.layer import LoraLayer
+from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
+from peft.utils.other import transpose
+from torch import nn
+from torch.nn import functional as F
+
+if is_quanto_available():
+    # ensure that there are no quanto imports unless optimum.quanto is installed
+    from optimum.quanto import QConv2d, QLinear
+else:
+    QConv2d, QLinear = None, None
+
+
+class QuantoLoraLinear(torch.nn.Module, LoraLayer):
+    """LoRA layer implementation for quanto QLinear"""
+
+    def __init__(
+        self,
+        base_layer,
+        adapter_name,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        init_lora_weights: bool = True,
+        use_rslora: bool = False,
+        use_dora: bool = False,
+        **kwargs,
+    ):
+        if use_dora:
+            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
+
+        super().__init__()
+        LoraLayer.__init__(self, base_layer)
+        self.fan_in_fan_out = fan_in_fan_out
+
+        self._active_adapter = adapter_name
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
+
+    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+        adapter_names = kwargs.pop("adapter_names", None)
+        if adapter_names is not None:
+            raise ValueError(f"{self.__class__.__name__} does not support mixed_batch_forward yet.")
+
+        if self.disable_adapters:
+            if self.merged:
+                self.unmerge()
+            result = self.base_layer(x, *args, **kwargs)
+        elif self.merged:
+            result = self.base_layer(x, *args, **kwargs)
+        else:
+            result = self.base_layer(x, *args, **kwargs)
+            for active_adapter in self.active_adapters:
+                if active_adapter not in self.lora_A.keys():
+                    continue
+                lora_A = self.lora_A[active_adapter]
+                lora_B = self.lora_B[active_adapter]
+                dropout = self.lora_dropout[active_adapter]
+                scaling = self.scaling[active_adapter]
+
+                requires_conversion = not torch.is_autocast_enabled()
+                if requires_conversion:
+                    expected_dtype = result.dtype
+                    x = x.to(lora_A.weight.dtype)
+
+                output = lora_B(lora_A(dropout(x)))
+                if requires_conversion:
+                    output = output.to(expected_dtype)
+                output = output * scaling
+                result = result + output
+
+        return result
+
+    def get_delta_weight(self, adapter):
+        return (
+            transpose(
+                self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
+                fan_in_fan_out=self.fan_in_fan_out,
+            )
+            * self.scaling[adapter]
+        )
+
+    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
+        from optimum.quanto import quantize_weight
+
+        adapter_names = check_adapters_to_merge(self, adapter_names)
+        if not adapter_names:
+            # no adapter to merge
+            return
+
+        base_layer = self.get_base_layer()
+        orig_weight = base_layer.weight
+
+        for active_adapter in adapter_names:
+            delta_weight = self.get_delta_weight(active_adapter)
+            # note: no in-place for safe_merge=False
+            new_weight_data = orig_weight + delta_weight
+            if safe_merge:
+                if not torch.isfinite(new_weight_data).all():
+                    raise ValueError(f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken")
+            quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis)
+            base_layer.weight._data = quantized._data
+            base_layer.weight._scale = quantized._scale
+            self.merged_adapters.append(active_adapter)
+
+    def unmerge(self) -> None:
+        from optimum.quanto import quantize_weight
+
+        if not self.merged:
+            warnings.warn("Already unmerged. Nothing to do.")
+            return
+
+        while len(self.merged_adapters) > 0:
+            active_adapter = self.merged_adapters.pop()
+            if active_adapter not in self.lora_A.keys():
+                continue
+
+            base_layer = self.get_base_layer()
+            orig_weight = base_layer.weight
+            new_weight_data = orig_weight - self.get_delta_weight(active_adapter)
+            quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis)
+            base_layer.weight._data = quantized._data
+            base_layer.weight._scale = quantized._scale
+
+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "lora." + rep
+
+
+class QuantoLoraConv2d(torch.nn.Module, LoraLayer):
+    """LoRA layer implementation for quanto QConv2d"""
+
+    def __init__(
+        self,
+        base_layer,
+        adapter_name,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        init_lora_weights: bool = True,
+        use_rslora: bool = False,
+        use_dora: bool = False,
+        **kwargs,
+    ):
+        if use_dora:
+            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
+
+        super().__init__()
+        LoraLayer.__init__(self, base_layer)
+
+        self._active_adapter = adapter_name
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora)
+
+    def update_layer(
+        self,
+        adapter_name,
+        r,
+        lora_alpha,
+        lora_dropout,
+        init_lora_weights,
+        use_rslora,
+        use_dora,
+    ):
+        # same as lora.layer.Conv2d
+        if r <= 0:
+            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
+
+        self.r[adapter_name] = r
+        self.lora_alpha[adapter_name] = lora_alpha
+        if lora_dropout > 0.0:
+            lora_dropout_layer = nn.Dropout(p=lora_dropout)
+        else:
+            lora_dropout_layer = nn.Identity()
+
+        self.lora_dropout[adapter_name] = lora_dropout_layer
+        # Actual trainable parameters
+        base_layer = self.get_base_layer()
+        kernel_size = base_layer.kernel_size
+        stride = base_layer.stride
+        padding = base_layer.padding
+        self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)
+        self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)
+        if use_rslora:
+            self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
+        else:
+            self.scaling[adapter_name] = lora_alpha / r
+
+        if init_lora_weights == "loftq":
+            self.loftq_init(adapter_name)
+        elif init_lora_weights:
+            self.reset_lora_parameters(adapter_name, init_lora_weights)
+
+        # call this before dora_init
+        self._move_adapter_to_device_of_base_layer(adapter_name)
+
+        if use_dora:
+            # TODO: Implement DoRA
+            self.dora_init(adapter_name)
+            self.use_dora[adapter_name] = True
+        else:
+            self.use_dora[adapter_name] = False
+
+        self.set_adapter(self.active_adapters)
+
+    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+        adapter_names = kwargs.pop("adapter_names", None)
+        if adapter_names is not None:
+            raise ValueError(f"{self.__class__.__name__} does not support mixed_batch_forward yet.")
+
+        if self.disable_adapters:
+            if self.merged:
+                self.unmerge()
+            result = self.base_layer(x, *args, **kwargs)
+        elif self.merged:
+            result = self.base_layer(x, *args, **kwargs)
+        else:
+            result = self.base_layer(x, *args, **kwargs)
+            for active_adapter in self.active_adapters:
+                if active_adapter not in self.lora_A.keys():
+                    continue
+                lora_A = self.lora_A[active_adapter]
+                lora_B = self.lora_B[active_adapter]
+                dropout = self.lora_dropout[active_adapter]
+                scaling = self.scaling[active_adapter]
+
+                requires_conversion = not torch.is_autocast_enabled()
+                if requires_conversion:
+                    expected_dtype = result.dtype
+                    x = x.to(lora_A.weight.dtype)
+
+                output = lora_B(lora_A(dropout(x)))
+                if requires_conversion:
+                    output = output.to(expected_dtype)
+                output = output * scaling
+                result = result + output
+
+        return result
+
+    def get_delta_weight(self, adapter):
+        # same as lora.layer.Conv2d
+        device = self.lora_B[adapter].weight.device
+        dtype = self.lora_A[adapter].weight.dtype
+
+        # In case users want to merge the adapter weights that are in
+        # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
+        # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
+        cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
+
+        weight_A = self.lora_A[adapter].weight
+        weight_B = self.lora_B[adapter].weight
+
+        if cast_to_fp32:
+            weight_A = weight_A.float()
+            weight_B = weight_B.float()
+
+        # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
+        if self.get_base_layer().weight.size()[2:4] == (1, 1):
+            # conv2d 1x1
+            output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(
+                3
+            ) * self.scaling[adapter]
+        else:
+            # conv2d 3x3
+            output_tensor = (
+                F.conv2d(
+                    weight_A.permute(1, 0, 2, 3),
+                    weight_B,
+                ).permute(1, 0, 2, 3)
+                * self.scaling[adapter]
+            )
+
+        if cast_to_fp32:
+            output_tensor = output_tensor.to(dtype=dtype)
+
+            # cast back the weights
+            self.lora_A[adapter].weight.data = weight_A.to(dtype)
+            self.lora_B[adapter].weight.data = weight_B.to(dtype)
+
+        return output_tensor
+
+    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
+        # same as lora.quanto.QuantoLoraLinear
+        from optimum.quanto import quantize_weight
+
+        adapter_names = check_adapters_to_merge(self, adapter_names)
+        if not adapter_names:
+            # no adapter to merge
+            return
+
+        base_layer = self.get_base_layer()
+        orig_weight = base_layer.weight
+
+        for active_adapter in adapter_names:
+            delta_weight = self.get_delta_weight(active_adapter)
+            # note: no in-place for safe_merge=False
+            new_weight_data = orig_weight + delta_weight
+            if safe_merge:
+                if not torch.isfinite(new_weight_data).all():
+                    raise ValueError(f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken")
+            quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis)
+            base_layer.weight._data = quantized._data
+            base_layer.weight._scale = quantized._scale
+            self.merged_adapters.append(active_adapter)
+
+    def unmerge(self) -> None:
+        # same as lora.quanto.QuantoLoraLinear
+        from optimum.quanto import quantize_weight
+
+        if not self.merged:
+            warnings.warn("Already unmerged. Nothing to do.")
+            return
+
+        while len(self.merged_adapters) > 0:
+            active_adapter = self.merged_adapters.pop()
+            if active_adapter not in self.lora_A.keys():
+                continue
+
+            base_layer = self.get_base_layer()
+            orig_weight = base_layer.weight
+            new_weight_data = orig_weight - self.get_delta_weight(active_adapter)
+            quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis)
+            base_layer.weight._data = quantized._data
+            base_layer.weight._scale = quantized._scale
+
+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "lora." + rep
+
+
+def dispatch_quanto(
+    target: torch.nn.Module,
+    adapter_name: str,
+    **kwargs: Any,
+) -> Optional[torch.nn.Module]:
+    new_module = None
+
+    if isinstance(target, BaseTunerLayer):
+        target_base_layer = target.get_base_layer()
+    else:
+        target_base_layer = target
+
+    if is_quanto_available() and isinstance(target_base_layer, QLinear):
+        new_module = QuantoLoraLinear(target, adapter_name, **kwargs)
+        target.weight = target_base_layer.weight
+
+        if hasattr(target, "bias"):
+            target.bias = target_base_layer.bias
+    elif is_quanto_available() and isinstance(target_base_layer, QConv2d):
+        new_module = QuantoLoraConv2d(target, adapter_name, **kwargs)
+        target.weight = target_base_layer.weight
+
+        if hasattr(target, "bias"):
+            target.bias = target_base_layer.bias
+
+    return new_module
+
+
+custom_module_mapping = {QConv2d: QuantoLoraConv2d, QLinear: QuantoLoraLinear} if is_quanto_available() else {}
+
+try:
+    # peft versions usually expose this global mapping.
+    from peft.tuners.lora.layer import LORA_MODULES_MAPPING
+
+    if is_quanto_available():
+        if QLinear is not None:
+            LORA_MODULES_MAPPING[QLinear] = QuantoLoraLinear
+        if QConv2d is not None:
+            LORA_MODULES_MAPPING[QConv2d] = QuantoLoraConv2d
+except Exception:
+    # If this peft version does not have it, skip silently and keep the existing behaviour.
+    pass
\ No newline at end of file
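# Sketch (illustrative, not part of the patch): one way the quanto-aware LoRA layers above
# could be wired into a peft LoRA config. It assumes a peft version that exposes the
# experimental LoraConfig._register_custom_module hook (guarded with hasattr below); the toy
# Sequential and the "proj" target name are placeholders.
import torch
from optimum.quanto import quantize, freeze, qint8
from peft import LoraConfig, get_peft_model

from diffsynth.utils.quantisation.peft_workarounds import custom_module_mapping

# A toy model standing in for a transformer block; "proj" is the layer we adapt.
model = torch.nn.Sequential()
model.add_module("proj", torch.nn.Linear(32, 32))
quantize(model, weights=qint8)
freeze(model)  # model.proj is now an optimum.quanto QLinear

config = LoraConfig(r=8, lora_alpha=8, target_modules=["proj"])
if hasattr(config, "_register_custom_module"):  # experimental peft hook; may not exist
    config._register_custom_module(custom_module_mapping)
peft_model = get_peft_model(model, config)  # QLinear targets get QuantoLoraLinear wrappers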
diff --git a/diffsynth/utils/quantisation/quanto_workarounds.py b/diffsynth/utils/quantisation/quanto_workarounds.py
new file mode 100644
index 00000000..5ba902aa
--- /dev/null
+++ b/diffsynth/utils/quantisation/quanto_workarounds.py
@@ -0,0 +1,174 @@
+import optimum
+import torch
+from optimum.quanto.tensor.packed import PackedTensor
+from optimum.quanto.tensor.weights.qbits import WeightQBitsTensor
+from optimum.quanto.tensor.weights.qbytes import WeightQBytesTensor
+
+_TORCH_TENSOR_DATA_DESCRIPTOR = torch.Tensor.data
+
+if torch.cuda.is_available():
+    # the marlin fp8 kernel needs some help with dtype casting for some reason
+    # see: https://github.com/huggingface/optimum-quanto/pull/296#issuecomment-2380719201
+    if torch.version.cuda:  # NVIDIA CUDA builds only (torch.version.cuda is None on ROCm)
+        from optimum.quanto.library.extensions.cuda import ext as quanto_ext
+
+        # Save the original operator
+        original_gemm_f16f8_marlin = torch.ops.quanto.gemm_f16f8_marlin
+
+        def fp8_marlin_gemm_wrapper(
+            a: torch.Tensor,
+            b_q_weight: torch.Tensor,
+            b_scales: torch.Tensor,
+            workspace: torch.Tensor,
+            num_bits: int,
+            size_m: int,
+            size_n: int,
+            size_k: int,
+        ) -> torch.Tensor:
+            # Ensure 'a' has the correct dtype
+            a = a.to(b_scales.dtype)
+            # Call the original operator
+            return original_gemm_f16f8_marlin(
+                a,
+                b_q_weight,
+                b_scales,
+                workspace,
+                num_bits,
+                size_m,
+                size_n,
+                size_k,
+            )
+
+        # Monkey-patch the operator
+        torch.ops.quanto.gemm_f16f8_marlin = fp8_marlin_gemm_wrapper
+
+    class TinyGemmQBitsLinearFunction(optimum.quanto.tensor.function.QuantizedLinearFunction):
+        @staticmethod
+        def forward(ctx, input, other, bias):
+            ctx.save_for_backward(input, other)
+            if type(input) is not torch.Tensor:
+                input = input.dequantize()
+            in_features = input.shape[-1]
+            out_features = other.shape[0]
+            output_shape = input.shape[:-1] + (out_features,)
+            output = torch._weight_int4pack_mm(
+                input.view(-1, in_features).to(dtype=other.dtype),
+                other._data._data,
+                other._group_size,
+                other._scale_shift,
+            )
+            output = output.view(output_shape)
+            if bias is not None:
+                output = output + bias
+            return output
+
+    from optimum.quanto.tensor.weights import tinygemm
+
+    tinygemm.qbits.TinyGemmQBitsLinearFunction = TinyGemmQBitsLinearFunction
+
+
+class WeightQBytesLinearFunction(optimum.quanto.tensor.function.QuantizedLinearFunction):
+    @staticmethod
+    def forward(ctx, input, other, bias=None):
+        ctx.save_for_backward(input, other)
+        input_device = getattr(input, "device", None)
+        if input_device is None and hasattr(input, "_data"):
+            input_device = input._data.device
+
+        if input_device is not None and hasattr(other, "_data"):
+            backing_data = other._data
+            backing_scale = getattr(other, "_scale", None)
+            if backing_data.device != input_device:
+                other._data = backing_data.to(input_device, non_blocking=True)
+            if backing_scale is not None and hasattr(backing_scale, "device") and backing_scale.device != input_device:
+                other._scale = backing_scale.to(input_device, non_blocking=True)
+
+        if isinstance(input, optimum.quanto.tensor.QBytesTensor):
+            output = torch.ops.quanto.qbytes_mm(input._data, other._data, input._scale * other._scale)
+        else:
+            in_features = input.shape[-1]
+            out_features = other.shape[0]
+            output_shape = input.shape[:-1] + (out_features,)
+            output = torch.ops.quanto.qbytes_mm(input.reshape(-1, in_features), other._data, other._scale)
+            output = output.view(output_shape)
+        if bias is not None:
+            output = output + bias
+        return output
+
+
+optimum.quanto.tensor.weights.qbytes.WeightQBytesLinearFunction = WeightQBytesLinearFunction
+
+
+def reshape_qlf_backward(ctx, gO):
+    # another one where we need .reshape instead of .view
+    input_gO = other_gO = bias_gO = None
+    input, other = ctx.saved_tensors
+    out_features, in_features = other.shape
+    if ctx.needs_input_grad[0]:
+        # grad(A@(B.t()) = gO => grad(A) = gO@(B.t().t()) = gO@B
+        input_gO = torch.matmul(gO, other)
+    if ctx.needs_input_grad[1]:
+        # grad(B@A.t()) = gO.t() => grad(B) = gO.t()@(A.t().t()) = gO.t()@A
+        other_gO = torch.matmul(
+            gO.reshape(-1, out_features).t(),
+            input.to(gO.dtype).reshape(-1, in_features),
+        )
+    if ctx.needs_input_grad[2]:
+        # Bias gradient is the sum on all dimensions but the last one
+        dim = tuple(range(gO.ndim - 1))
+        bias_gO = gO.sum(dim)
+    return input_gO, other_gO, bias_gO
+
+
+optimum.quanto.tensor.function.QuantizedLinearFunction.backward = reshape_qlf_backward
+
+
+def _bridge_storage_accessors(tensor_cls, data_attr: str) -> None:
+    if getattr(tensor_cls, "_simpletuner_storage_bridge_applied", False):
+        return
+
+    def _backing_tensor(self):
+        backing = getattr(self, data_attr, None)
+        if backing is None:
+            raise AttributeError(f"{tensor_cls.__name__} is missing expected backing tensor '{data_attr}'")
+        return backing
+
+    def _data_ptr(self):
+        return _backing_tensor(self).data_ptr()
+
+    def _untyped_storage(self):
+        return _backing_tensor(self).untyped_storage()
+
+    def _storage(self):
+        return _backing_tensor(self).storage()
+
+    tensor_cls.data_ptr = _data_ptr  # type: ignore[assignment]
+    tensor_cls.untyped_storage = _untyped_storage  # type: ignore[assignment]
+    tensor_cls.storage = _storage  # type: ignore[assignment]
+    tensor_cls._simpletuner_storage_bridge_applied = True  # type: ignore[attr-defined]
+
+
+_bridge_storage_accessors(WeightQBytesTensor, "_data")
+_bridge_storage_accessors(WeightQBitsTensor, "_data")
+_bridge_storage_accessors(PackedTensor, "_data")
+
+
+def _mirror_tensor_data_property(tensor_cls, attrs: tuple[str, ...]) -> None:
+    if getattr(tensor_cls, "_simpletuner_data_bridge_applied", False):
+        return
+
+    def _data_get(self):
+        return _TORCH_TENSOR_DATA_DESCRIPTOR.__get__(self, type(self))
+
+    def _data_set(self, value):
+        _TORCH_TENSOR_DATA_DESCRIPTOR.__set__(self, value)
+        for attr in attrs:
+            if hasattr(value, attr) and hasattr(self, attr):
+                setattr(self, attr, getattr(value, attr))
+
+    tensor_cls.data = property(_data_get, _data_set)  # type: ignore[assignment]
+    tensor_cls._simpletuner_data_bridge_applied = True  # type: ignore[attr-defined]
+
+
+_mirror_tensor_data_property(WeightQBytesTensor, ("_data", "_scale", "activation_qtype", "_axis", "_qtype"))
+_mirror_tensor_data_property(WeightQBitsTensor, ("_data", "_scale", "_shift", "_axis", "_qtype"))
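# Illustrative sketch (not part of the patch) of the load-then-quantise ordering the training
# changes below rely on: quantise a module with _quanto_model() while it still lives on CPU,
# and only then move the already-int8 weights to the accelerator. The tiny Sequential stands
# in for pipe.vae; the device choice is a placeholder.
import torch

from diffsynth.utils.quantisation import _quanto_model

vae_like = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.Conv2d(8, 3, 3, padding=1),
)

# Quantise on CPU first, so the full-precision weights never need to fit in VRAM...
vae_like = _quanto_model(
    vae_like,
    model_precision="int8-quanto",
    base_model_precision="int8-quanto",
    quantize_activations=False,  # weights only, as the patch recommends
)

# ...then move the frozen int8 module to the training device.
device = "cuda" if torch.cuda.is_available() else "cpu"
vae_like.to(device)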
diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py
index 6a0e4b6a..72171eb7 100644
--- a/examples/qwen_image/model_training/train.py
+++ b/examples/qwen_image/model_training/train.py
@@ -1,7 +1,16 @@
 import torch, os, argparse, accelerate
+
+# Importing these modules applies the quanto / peft workarounds (monkeypatches).
+import diffsynth.utils.quantisation.quanto_workarounds  # noqa: F401
+import diffsynth.utils.quantisation.peft_workarounds  # noqa: F401
+
 from diffsynth.core import UnifiedDataset
 from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
 from diffsynth.diffusion import *
+
+from diffsynth.utils.quantisation import quantise_model, _quanto_model
+from types import SimpleNamespace
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
@@ -20,13 +29,76 @@ def __init__(
         offload_models=None,
         device="cpu",
         task="sft",
+        base_model_precision: str = "no_change",
+        text_encoder_1_precision: str = "no_change",
+        quantize_activations: bool = False,
+        result_image_field_name: str = "result_image",
+        quantize_vae: bool = False,  # may slightly degrade very fine textures (usually acceptable)
     ):
         super().__init__()
         # Load models
         model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
         tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
         processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
-        self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
+
+        # Enable quanto whenever base_model_precision or text_encoder_1_precision contains "quanto".
+        use_quanto = (
+            (base_model_precision is not None and "quanto" in base_model_precision.lower())
+            or (text_encoder_1_precision is not None and "quanto" in text_encoder_1_precision.lower())
+        )
+
+        load_device = "cpu" if use_quanto else device
+        load_dtype = torch.bfloat16
+
+        # 1. Load the whole pipeline on load_device first.
+        self.pipe = QwenImagePipeline.from_pretrained(
+            torch_dtype=load_dtype,
+            device=load_device,
+            model_configs=model_configs,
+            tokenizer_config=tokenizer_config,
+            processor_config=processor_config,
+        )
+
+        # 2. If quanto is enabled, apply SimpleTuner-style quantisation to the DiT and the text encoder.
+        if use_quanto:
+            fake_args = SimpleNamespace(
+                base_model_precision=base_model_precision,
+                text_encoder_1_precision=text_encoder_1_precision,
+                text_encoder_2_precision="no_change",
+                text_encoder_3_precision="no_change",
+                text_encoder_4_precision="no_change",
+                quantize_activations=quantize_activations,
+            )
+
+            dit, text_encoders, _, _ = quantise_model(
+                model=self.pipe.dit,
+                text_encoders=[self.pipe.text_encoder],
+                controlnet=None,
+                ema=None,
+                args=fake_args,
+            )
+            self.pipe.dit = dit
+            if text_encoders is not None and len(text_encoders) > 0:
+                self.pipe.text_encoder = text_encoders[0]
+
+        # 2b. Optionally quantise the VAE as well.
+        if quantize_vae:
+            if hasattr(self.pipe, "vae") and self.pipe.vae is not None:
+                # Use the same precision / settings for the VAE as for the base model.
+                self.pipe.vae = _quanto_model(
+                    model=self.pipe.vae,
+                    model_precision=base_model_precision,  # typically "int8-quanto"
+                    base_model_precision=base_model_precision,  # used for the fallback check
+                    quantize_activations=quantize_activations,  # False is recommended: quantise weights only at first
+                )
+
+        # 3. Move the whole pipeline to the accelerator device.
+        if load_device != device:
+            self.pipe.to(device)
+            self.pipe.device = device
+
+        # 4. Keep the original split + peft LoRA logic.
         self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
 
         # Training mode
@@ -51,6 +123,7 @@ def __init__(
             "direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
             "direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
         }
+        self.result_image_field_name = result_image_field_name
 
     def get_pipeline_inputs(self, data):
         inputs_posi = {"prompt": data["prompt"]}
@@ -58,9 +131,9 @@ def get_pipeline_inputs(self, data):
         inputs_shared = {
             # Assume you are using this pipeline for inference,
             # please fill in the input parameters.
-            "input_image": data["image"],
-            "height": data["image"].size[1],
-            "width": data["image"].size[0],
+            "input_image": data[self.result_image_field_name],
+            "height": data[self.result_image_field_name].size[1],
+            "width": data[self.result_image_field_name].size[0],
             # Please do not modify the following parameters
             # unless you clearly know what this will cause.
             "cfg_scale": 1,
@@ -86,7 +159,49 @@ def qwen_image_parser():
     parser = add_general_config(parser)
     parser = add_image_size_config(parser)
     parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
-    parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
+    parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor.")
+
+    # Quantisation arguments aligned with SimpleTuner.
+    parser.add_argument(
+        "--base_model_precision",
+        type=str,
+        default="no_change",
+        choices=[
+            "no_change",
+            "fp32",
+            "fp16",
+            "bf16",
+            "int2-quanto",
+            "int4-quanto",
+            "int8-quanto",
+            "fp8-quanto",
+            "fp8uz-quanto",
+        ],
+        help="Precision for the DiT / main diffusion model. Use a '*-quanto' value to enable optimum.quanto quantisation.",
+    )
+    parser.add_argument(
+        "--text_encoder_1_precision",
+        type=str,
+        default="no_change",
+        help="Precision for the first text encoder. Defaults to no_change (i.e. keep bf16), matching SimpleTuner configs.",
+    )
+    parser.add_argument(
+        "--quantize_activations",
+        action="store_true",
+        help="When using quanto, also quantize activations in addition to weights.",
+    )
+    parser.add_argument(
+        "--result_image_field_name",
+        type=str,
+        default="result_image",
+        help="The field name of the image generated by the model in the dataset JSON.",
+    )
+    parser.add_argument(
+        "--quantize_vae",
+        action="store_true",
+        help="Also quantize the VAE. May cause slight degradation in very fine textures (generally acceptable).",
+    )
+
     return parser
 
 
@@ -130,6 +245,11 @@ def qwen_image_parser():
         offload_models=args.offload_models,
         task=args.task,
         device=accelerator.device,
+        base_model_precision=args.base_model_precision,
+        text_encoder_1_precision=args.text_encoder_1_precision,
+        quantize_activations=args.quantize_activations,
+        result_image_field_name=args.result_image_field_name,
+        quantize_vae=args.quantize_vae,
     )
     model_logger = ModelLogger(
         args.output_path,
diff --git a/pyproject.toml b/pyproject.toml
index cb00b4d3..a23061ff 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
     "accelerate",
     "peft",
     "datasets",
+    "optimum[quanto]",
 ]
 classifiers = [
     "Programming Language :: Python :: 3",