diffsynth/utils/quantisation/__init__.py (200 additions, 0 deletions)
@@ -0,0 +1,200 @@
import logging
import torch

logger = logging.getLogger(__name__)


def _quanto_type_map(model_precision: str):
    if model_precision is None or model_precision == "no_change":
        return None
    from optimum.quanto import qfloat8, qfloat8_e4m3fnuz, qint2, qint4, qint8

    mp = model_precision.lower()
    if mp == "int2-quanto":
        return qint2
    elif mp == "int4-quanto":
        return qint4
    elif mp == "int8-quanto":
        return qint8
    elif mp in ("fp8-quanto", "fp8uz-quanto"):
        if torch.backends.mps.is_available():
            logger.warning(
                "MPS doesn't support dtype float8, please use bf16/fp16/int8-quanto instead."
            )
            return None
        return qfloat8 if mp == "fp8-quanto" else qfloat8_e4m3fnuz
    else:
        raise ValueError(f"Invalid quantisation level: {model_precision}")


def _quanto_model(
    model,
    model_precision,
    base_model_precision=None,
    quantize_activations: bool = False,
):
    try:
        from optimum.quanto import quantize, freeze  # noqa
        # The import alone triggers the monkeypatches in quanto_workarounds.
        from diffsynth.utils.quantisation import quanto_workarounds  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "To use Quanto, please install the optimum library: `pip install \"optimum[quanto]\"`"
        ) from e

    if model is None:
        return model
    if model_precision is None:
        model_precision = base_model_precision
    if model_precision in (None, "no_change"):
        logger.info("...No quantisation applied to %s.", model.__class__.__name__)
        return model

    logger.info("Quantising %s. Using %s.", model.__class__.__name__, model_precision)
    weight_quant = _quanto_type_map(model_precision)
    if weight_quant is None:
        logger.info("Quantisation level %s resolved to None, skipping.", model_precision)
        return model

    extra_quanto_args = {}
    if quantize_activations:
        logger.info("Quanto: Freezing model weights and activations")
        extra_quanto_args["activations"] = weight_quant
    else:
        logger.info("Quanto: Freezing model weights only")

    quantize(model, weights=weight_quant, **extra_quanto_args)
    freeze(model)
    return model


def get_quant_fn(base_model_precision):
    if base_model_precision is None:
        return None
    precision = base_model_precision.lower()
    if precision == "no_change":
        return None
    if "quanto" in precision:
        return _quanto_model
    # torchao is not supported here yet.
    return None


def quantise_model(
    model=None,
    text_encoders: list = None,
    controlnet=None,
    ema=None,
    args=None,
    return_dict: bool = False,
):
    # Unpack text_encoders; at most four are supported, for compatibility
    # with SimpleTuner's interface.
    te1 = te2 = te3 = te4 = None
    if text_encoders is not None:
        if len(text_encoders) > 0:
            te1 = text_encoders[0]
        if len(text_encoders) > 1:
            te2 = text_encoders[1]
        if len(text_encoders) > 2:
            te3 = text_encoders[2]
        if len(text_encoders) > 3:
            te4 = text_encoders[3]
Comment on lines +92 to +101
Contributor

medium

This unpacking of text_encoders is somewhat redundant and hard-codes a maximum of four encoders. It can be simplified for readability and extensibility:

    te = [None] * 4
    if text_encoders is not None:
        num_tes = min(len(text_encoders), 4)
        te[:num_tes] = text_encoders[:num_tes]
    te1, te2, te3, te4 = te
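
For illustration, the suggested slicing handles any input length, including fewer or more than four encoders (a quick check, with hypothetical strings standing in for loaded encoder modules):

    text_encoders = ["te_a", "te_b"]  # hypothetical placeholders for two encoders

    te = [None] * 4
    if text_encoders is not None:
        num_tes = min(len(text_encoders), 4)
        te[:num_tes] = text_encoders[:num_tes]
    te1, te2, te3, te4 = te
    print(te1, te2, te3, te4)  # te_a te_b None None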


    models = [
        (
            model,
            {
                "quant_fn": get_quant_fn(args.base_model_precision),
                "model_precision": args.base_model_precision,
                "quantize_activations": getattr(args, "quantize_activations", False),
            },
        ),
        (
            controlnet,
            {
                "quant_fn": get_quant_fn(args.base_model_precision),
                "model_precision": args.base_model_precision,
                "quantize_activations": getattr(args, "quantize_activations", False),
            },
        ),
        (
            te1,
            {
                "quant_fn": get_quant_fn(args.text_encoder_1_precision),
                "model_precision": args.text_encoder_1_precision,
                "base_model_precision": args.base_model_precision,
            },
        ),
        (
            te2,
            {
                "quant_fn": get_quant_fn(args.text_encoder_2_precision),
                "model_precision": args.text_encoder_2_precision,
                "base_model_precision": args.base_model_precision,
            },
        ),
        (
            te3,
            {
                "quant_fn": get_quant_fn(args.text_encoder_3_precision),
                "model_precision": args.text_encoder_3_precision,
                "base_model_precision": args.base_model_precision,
            },
        ),
        (
            te4,
            {
                "quant_fn": get_quant_fn(args.text_encoder_4_precision),
                "model_precision": args.text_encoder_4_precision,
                "base_model_precision": args.base_model_precision,
            },
        ),
        (
            ema,
            {
                "quant_fn": get_quant_fn(args.base_model_precision),
                "model_precision": args.base_model_precision,
                "quantize_activations": getattr(args, "quantize_activations", False),
            },
        ),
    ]
Comment on lines +103 to +160
Contributor

high

The way the models list is built contains a lot of duplicated code, which makes it hard to read and maintain. It also relies on the caller (train.py) supplying an args object that carries every text_encoder_*_precision attribute, creating unnecessarily tight coupling (see the fake_args implementation in train.py).

Consider refactoring this into a configuration list plus a loop, and using getattr with a default value to make the function more robust, decoupling this utility from the specific training script:

    model_definitions = [
        {"model": model, "precision_key": "base_model_precision", "quantize_activations": True},
        {"model": controlnet, "precision_key": "base_model_precision", "quantize_activations": True},
        {"model": te1, "precision_key": "text_encoder_1_precision"},
        {"model": te2, "precision_key": "text_encoder_2_precision"},
        {"model": te3, "precision_key": "text_encoder_3_precision"},
        {"model": te4, "precision_key": "text_encoder_4_precision"},
        {"model": ema, "precision_key": "base_model_precision", "quantize_activations": True},
    ]

    models = []
    for definition in model_definitions:
        m = definition["model"]
        precision_key = definition["precision_key"]
        # Use getattr with a default to make this function more robust
        model_precision = getattr(args, precision_key, "no_change")
        
        qargs = {
            "quant_fn": get_quant_fn(model_precision),
            "model_precision": model_precision,
            "base_model_precision": args.base_model_precision,
        }
        if definition.get("quantize_activations"):
            qargs["quantize_activations"] = getattr(args, "quantize_activations", False)
        
        models.append((m, qargs))
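
To make the robustness point concrete, here is what the getattr fallback buys; SimpleNamespace stands in for a hypothetical, minimal args object that lacks the text-encoder attributes:

    from types import SimpleNamespace

    args = SimpleNamespace(base_model_precision="int8-quanto")

    # Direct attribute access would raise AttributeError here;
    # the suggested fallback degrades gracefully instead.
    print(getattr(args, "text_encoder_1_precision", "no_change"))  # -> no_change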


    for i, (m, qargs) in enumerate(models):
        quant_fn = qargs["quant_fn"]
        if m is None or quant_fn is None:
            continue
        quant_args_combined = {
            "model_precision": qargs["model_precision"],
            "base_model_precision": qargs.get("base_model_precision", args.base_model_precision),
            "quantize_activations": qargs.get(
                "quantize_activations", getattr(args, "quantize_activations", False)
            ),
        }
        logger.info("Quantising %s with %s", m.__class__.__name__, quant_args_combined)
        models[i] = (quant_fn(m, **quant_args_combined), qargs)

    # Unpack the (possibly quantised) models.
    model, controlnet, te1, te2, te3, te4, ema = [m for (m, _) in models]

    # Repack text_encoders.
    new_text_encoders = []
    if te1 is not None:
        new_text_encoders.append(te1)
    if te2 is not None:
        new_text_encoders.append(te2)
    if te3 is not None:
        new_text_encoders.append(te3)
    if te4 is not None:
        new_text_encoders.append(te4)
    if len(new_text_encoders) == 0:
        new_text_encoders = None
Comment on lines +180 to +190
Contributor

medium

This repacking of text_encoders is also somewhat redundant. A list comprehension makes it more concise:

    new_text_encoders = [te for te in [te1, te2, te3, te4] if te is not None]
    if not new_text_encoders:
        new_text_encoders = None


    if return_dict:
        return {
            "model": model,
            "text_encoders": new_text_encoders,
            "controlnet": controlnet,
            "ema": ema,
        }

    return model, new_text_encoders, controlnet, ema
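
For reviewers trying this branch locally, a minimal usage sketch of quantise_model; this is an assumption-laden example, not part of the PR. The SimpleNamespace mirrors the fake_args pattern discussed above, and the nn.Linear modules are toy placeholders for the real transformer and text encoders. Note that the current implementation reads all four text_encoder_*_precision attributes directly, so each must be present:

    from types import SimpleNamespace

    import torch.nn as nn

    from diffsynth.utils.quantisation import quantise_model

    # Toy placeholders; in train.py these are the loaded transformer and encoders.
    transformer = nn.Linear(8, 8)
    encoders = [nn.Linear(8, 8), nn.Linear(8, 8)]

    # Hypothetical args namespace carrying every attribute quantise_model reads.
    args = SimpleNamespace(
        base_model_precision="int8-quanto",
        text_encoder_1_precision="no_change",
        text_encoder_2_precision="no_change",
        text_encoder_3_precision="no_change",
        text_encoder_4_precision="no_change",
        quantize_activations=False,
    )

    # Only the base model is quantised here; the encoders pass through unchanged.
    model, text_encoders, controlnet, ema = quantise_model(
        model=transformer,
        text_encoders=encoders,
        controlnet=None,
        ema=None,
        args=args,
    )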