开发指南
浮点模型的要求
symbolic_trace
和 PyTorch 的量化训练类似,horizon_plugin_pytorch 基于 fx 设计和开发,因此,要求浮点模型必须是可以正确的完成 symbolic_trace 的
仅支持部分算子
由于 BPU 只支持数量有限的算子,因此,horizon_plugin_pytorch 只支持算子列表中的算子和基于 BPU 限制而内部特殊定义的特殊算子。
构建量化友好模型
浮点模型变为定点模型的过程存在一定的精度误差,越是量化友好的浮点模型, qat 精度提升越容易,量化后的精度也越高。一般而言,有以下几种情况会导致模型变得量化不友好:
-
使用有精度风险的算子。例如: softmax , layernorm 等(详见 op 文档),这类算子一般底层由查表或多个 op 拼接实现,容易发生掉点问题。
-
一次 forward 中多次调用同一算子。同一算子多次调用,对应的输出分布存在差异,但只会统计一组量化参数,当多次调用的输出分布差异过大时,量化误差会变大。
-
add , cat 等多输入算子的不同输入差异过大,可能造成较大误差。
-
数据分布不合理。plugin 采用的是均匀对称量化,所以 0 均值的均匀分布最好,应 尽量避免长尾和离群点。同时,数值范围需要与量化 bit 相匹配,如果使用int8量化分布为 [-1000, 1000] 均匀分布的数据,那么精度显然也是不够的。例如,下面三个分布图,从左到右对量化的友好性依次递减,模型中大部分数值的分布应当为中间这种分布。在实际使用中,可以用 debug 工具查看模型 weight 和 feature map 的分布是否量化友好。因为模型冗余性的存在,有些看起来分布非常量化不友好的 op 并不会显著降低模型的最终精度,需要结合实际的 qat 训练难度和最后达到的量化精度综合考虑。
那么如何使得模型更加量化友好呢?具体来说:
-
尽量少使用精度风险过大的算子,详见 op 文档。
-
保证多次调用的共享算子每次调用的输出分布差异不要太大,或者将共享算子拆开分别单独使用。
-
避免多输入算子不同输入的数值范围差异过大。
-
使用 int16 量化数值范围和误差都非常大的 op 。可通过 debug 工具找到这类 op 。
-
通过调大 weight decay ,增加数据增强等方式防止模型过拟合。过拟合模型容易出现较大数值,且对输入非常敏感,轻微的误差可能导致输出完全错误。
-
使用 BN 。
-
对模型输入做关于0对称的归一化。
需要注意的是, qat 自身具有一定的调整能力,量化不友好并不代表不能量化,很多情况下,即使出现上面的不适合量化的现象,仍然可以量化得很好。因为上述建议也可能会导致浮点模型精度下降,所以应当在 qat 精度无法达标时再尝试上述 建议,尤其是 1 - 5 条建议,最后应当是在浮点模型精度和量化模型精度中找一个平衡点。
qconfig 详解
什么是 qconfig
模型的量化方式由 qconfig 决定,在准备 qat / calibration 模型之前,需要先给模型设置 qconfig。我们不推荐您自定义 qconfig,尽量只使用预定义好的qconfig变量,因为自定义 qconfig 需要对具体的处理器限制认知清晰,详细了解训练工具的工作原理,定义出错可能导致模型无法正常收敛、模型无法编译等问题,浪费大量时间和人力。
目前,Plugin 中维护了两个版本的qconfig,早期版本的 qconfig 将在不久的将来被废弃,我们只推荐您使用此文档中介绍的 qconfig 用法。
如何获取 qconfig
- 使用封装好的 qconfig 变量。这些 qconfig 存放在
horizon_plugin_pytorch/quantization/qconfig.py
中,可以适用于 绝大多数情况。包括:
from horizon_plugin_pytorch.quantization.qconfig import (
default_calib_8bit_fake_quant_qconfig,
default_qat_8bit_fake_quant_qconfig,
default_qat_8bit_fixed_act_fake_quant_qconfig,
default_calib_8bit_weight_16bit_act_fake_quant_qconfig,
default_qat_8bit_weight_16bit_act_fake_quant_qconfig,
default_qat_8bit_weight_16bit_fixed_act_fake_quant_qconfig,
default_qat_8bit_weight_32bit_out_fake_quant_qconfig, # 参考算子列表,支持高精度输出的算子可以设置此 qconfig 获得更高的精度
default_calib_8bit_weight_32bit_out_fake_quant_qconfig, # 参考算子列表,支持高精度输出的算子可以设置此 qconfig 获得更高的精度
)
- 使用
get_default_qconfig
接口。此接口较固定 qconfig 变量更灵活,我们推荐您对量化和硬件限制有清晰认知之后再使用。常用参数和解释如下:
from horizon_plugin_pytorch.quantization.qconfig import get_default_qconfig
qconfig = get_default_qconfig(
activation_fake_quant="fake_quant", # 支持 fake_quant, lsq, pact,常用 fake quant
weight_fake_quant="fake_quant", # 支持 fake_quant, lsq, pact,常用 fake quant
activation_observer="min_max", # 支持 min_max, fixed_scale, clip, percentile, clip_std, mse, kl
weight_observer="min_max", # 支持 min_max, fixed_scale, clip, percentile, clip_std, mse, kl
activation_qkwargs={
"dtype": qint16, # 由具体算子决定是否支持 int16
"is_sync_quantize": False, # 是否同步统计数据,默认关闭提升forward速度
"averaging_constant": 0.01 # 滑动平均系数,设置为0时,scale不更新
},
weight_qkwargs={ # 只支持 dtype = qint8, qscheme = torch.per_channel_symmetric, ch_axis = 0, 不建议做额外配置
"dtype": qint8,
"qscheme": torch.per_channel_symmetric,
"ch_axis": 0,
},
)
如何设置 qconfig
共有三种设置方法,我们推荐您使用前两种,最后一种设置方式将废弃。
- 直接设置 qconfig 属性。此方法优先级最高,其余方法不会覆盖直接设置的 qconfig。
model.qconfig = default_qat_8bit_fake_quant_qconfig
- qconfig 模板。在 prepare 接口上指定 qconfig setter 和 example_inputs,自动为模型设置 qconfig。
model = prepare_qat_fx(
model,
example_inputs=data,
qconfig_setter=default_qat_qconfig_setter,
)
- qconfig_dict。在 prepare_qat_fx 接口上指定 qconfig_dict。此用法将逐步废弃,如无兼容性需求,不推荐再使用,这里不展开介绍。
model = prepare_qat_fx(
model,
qconfig_dict={"": default_qat_qconfig_setter},
)
qconfig 模板
长期以来,配置 qconfig 出错的问题经常发生,因此我们开发了 qconfig 模板。qconfig 模板基于 subclass trace 方案感知模型的图结构,并按设定的规则自动设置 qconfig,是我们最推荐的设置 qconfig 方法。用法如下:
qat_model = prepare_qat_fx(
model,
example_inputs=example_input, # 用来感知图结构
qconfig_setter=( # qconfig 模板,支持传入多个模板,优先级从高到低。
sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter(table, ratio=0.2),
default_calibration_qconfig_setter,
)
)
模板的优先级低于直接给模型设置 qconfig 属性,如果模型在 prepare 之前已经使用 model.qconfig = xxx 进行了配置,那么模板将不会生效。如果没有特殊需求,我们不推荐将两者混合使用,这很容易引发低级错误。绝大多数情况下,我们推荐您使用模板和 model.qconfig = xxx 两种设置方式中的一种即可满足需求。
模板可分为三类:
- 固定模板。固定模板中 calibration / qat / qat_fixed_act_scale 区别在于使用的 observer 类型和 scale 更新逻辑,分别用于校准,qat 训练,固定 activation scale qat 训练。default 模板( default_calibration_qconfig_setter / default_qat_qconfig_setter / default_qat_fixed_act_qconfig_setter )会做三件事:首先,将可以设置的高精度输出都设置上,对于不支持高精度的输出将给出提示;然后,从 grid sample 算子的 grid 输入向前搜索,直到出现第一个 gemm 类算子或者QuantStub,将中间的所有算子都设置为 int16。根据经验这里的 grid 一般表达范围较宽,int8 有较大可能不满足精度需求;最后,将其余算子设置为 int8。int16 模板( qat_8bit_weight_16bit_act_qconfig_setter / qat_8bit_weight_16bit_fixed_act_qconfig_setter / calibration_8bit_weight_16bit_act_qconfig_setter )会做两件事:首先,将可以设置的高精度输出都设置上,对于不支持高精度的输出将给出提示;其次,将其余算子设置为 int16。
from horizon_plugin_pytorch.quantization.qconfig_template import (
default_calibration_qconfig_setter,
default_qat_qconfig_setter,
default_qat_fixed_act_qconfig_setter,
qat_8bit_weight_16bit_act_qconfig_setter,
qat_8bit_weight_16bit_fixed_act_qconfig_setter,
calibration_8bit_weight_16bit_act_qconfig_setter,
)
- 敏感度模板。敏感度模板有 sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,三者的区别和固定模板中三者的区别一致,也是分别用于校准,qat 训练,固定 activation scale qat 训练。 敏感度模板的第一个输入是精度 debug 工具产生的敏感度结果,第二个参数可以指定 ratio 或 topk ,敏感度模板会将量化敏感度最高的 topk 个算子设置为 int16。搭配固定模板,可以轻松实现混合精度调优。
from horizon_plugin_pytorch.quantization.qconfig_template import (
default_calibration_qconfig_setter,
default_qat_qconfig_setter,
default_qat_fixed_act_qconfig_setter,
qat_8bit_weight_16bit_act_qconfig_setter,
qat_8bit_weight_16bit_fixed_act_qconfig_setter,
calibration_8bit_weight_16bit_act_qconfig_setter,
sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter,
sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter,
)
table = torch.load("output_0-0_dataindex_1_sensitive_ops.pt")
qat_model = prepare_qat_fx(
model,
example_inputs=example_input,
qconfig_setter=(
sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2),
default_calibration_qconfig_setter,
)
)
- 自定义模板。自定义模板只有 ModuleNameQconfigSetter,需要传入模块名和对应 qconfig 的字典,一般用于设置 fixed scale 等特殊需求,可以和固定模板,敏感度模板搭配使用。
from horizon_plugin_pytorch.quantization.qconfig_template import (
default_calibration_qconfig_setter,
default_qat_qconfig_setter,
default_qat_fixed_act_qconfig_setter,
qat_8bit_weight_16bit_act_qconfig_setter,
qat_8bit_weight_16bit_fixed_act_qconfig_setter,
calibration_8bit_weight_16bit_act_qconfig_setter,
sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter,
sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter,
ModuleNameQconfigSetter,
)
table = torch.load("output_0-0_dataindex_1_sensitive_ops.pt")
module_name_to_qconfig = {
"op_1": default_qat_8bit_fake_quant_qconfig,
"op_2": get_default_qconfig(
activation_observer="fixed_scale",
activation_qkwargs={
"dtype": qint16,
"scale": OP2_MAX / QINT16_MAX,
},
)
}
qat_model = prepare_qat_fx(
model,
example_inputs=example_input,
qconfig_setter=(
ModuleNameQconfigSetter(module_name_to_qconfig),
sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2),
default_calibration_qconfig_setter,
)
)
Calibration 指南
在量化中,一个重要的步骤是确定量化参数,合理的初始量化参数能够显著提升模型精度并加快模型的收敛速度。Calibration 就是在浮点模型中插入 Observer,使用少量训练数据,在模型 forward 过程中统计各处的数据分布,以确定合理的量化参数的过程。虽然不做 Calibration 也可以进行量化训练,但一般来说,它对量化训练有益无害,所以推荐用户将此步骤作为必选项。
流程和示例
Calibration 与 QAT 的整体流程如下图所示:
下面分别介绍各个步骤:
-
构建并训练浮点模型。参考 horizon_plugin_pytorch 快速入门章节中的 获取浮点模型 小节内容。
-
在浮点模型上插入 Observer 节点。参考 horizon_plugin_pytorch 快速入门章节中的 Calibration 小节内容。使用
prepare_qat_fx
方法转化浮点模型前,需要为模型设置qconfig
。model.qconfig = horizon.quantization.get_default_qconfig()
get_default_qconfig
可以为weight
和activation
设置不同的observer
。目前,calibration 可选observer
有 "min_max"、 "percentile"、 "mse"、 "kl" 和 "mix"。如无特殊需求,weight_observer
推荐使用默认的 "min_max",activation_observer
推荐使用 "mse"。特殊用法和调试技巧见下面的常见算法介绍。fake_quant
参数对 Calibration 结果无影响,保留默认状态即可。def get_default_qconfig(
activation_fake_quant: Optional[str] = "fake_quant",
weight_fake_quant: Optional[str] = "fake_quant",
activation_observer: Optional[str] = "min_max",
weight_observer: Optional[str] = "min_max",
activation_qkwargs: Optional[Dict] = None,
weight_qkwargs: Optional[Dict] = None,
): -
设置
fake quantize
状态为CALIBRATION
。horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.CALIBRATION)
fake quantize
一共有三种状态,分别需要在QAT
、calibration
、validation
前将模型的fake quantize
设置为对应的状态。在 calibration 状态下,仅观测各算子输入输出的统计量。在 QAT 状态下,除观测统计量外还会进行伪量化操作。而在 validation 状态下,不会观测统计量,仅进行伪量化操作。class FakeQuantState(Enum):
QAT = "qat"
CALIBRATION = "calibration"
VALIDATION = "validation" -
calibration。把准备好的校准数据喂给模型,模型在 forward 过程中由 observer 观测相关统计量。
-
设置模型状态为 eval 并设置
fake quantize
状态为VALIDATION
。model.eval()
horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.VALIDATION) -
验证
calibration
效果。如果效果满意,则可以直接将模型转为定点或在此基础上进行量化训练,不满意则调整calibration qconfig
中的参数继续 calibration。
常用算法介绍
有关每个算子的参数说明,请参考文末 API 文档。
算法 | 速度排名 | 精度排名 | 易用性排名 |
---|---|---|---|
min_max | 1 | 5 | 1 |
percentile | 2 | 4 | 4 |
mse | 4 | 1 | 2 |
kl | 5 | 2 | 3 |
mix | 3 | 2 | 1 |
常用的几种校准方法性能如上表所示,数字越小越好,速度表示相同数据校准耗时,精度表示该方法在大多数模型上的校准效果,易用性表示该方法的调参复杂度。
对于同一模型而言,不同方法不同参数的精度/速度会存在较大差别,最新的一些研究工作也表明,没有一种方法可以在所有模型上都取得最好的精度,需要针对地调整其参数。所以推荐用户对这几种校准方法都进行尝试。
-
min_max。此方法仅统计最大值最小值的滑动平均,用于快速确定 Batch size、average_constant 等通用参数,没有太多技巧。
-
percentile。此方法是所有方法中精度上限最高的,但也是调整起来最麻烦的,如果通过其他方法或本方法的默认参数就可以满足精度要求,那么不建议在调参上花太多时间。percentile 可调的参数一共有两个 bins、percentile。bins 越多,max 的候选项间隔越小,可供调整的粒度越细,但也意味着更高的计算耗时。建议先确定 percentile 再调整 bins,两者交替迭代缩小调参范围直至达到满意的效果。绝大部分情况下 bins 取 2048 提供的调整粒度完全足够,不需要单独调整这个参数。以下是一个模型的调参路径:
顺序 | percentile | bins | 精度 |
---|---|---|---|
1 | 99.99 | 2048 | 53.75 |
2 | 99.99 | 4096 | 54.38 |
3 | 99.995 | 4096 | 16.25 |
4 | 99.985 | 4096 | 32.67 |
5 | 99.9875 | 4096 | 57.06 |
6 | 99.9875 | 8192 | 62.84 |
7 | 99.98875 | 8192 | 57.62 |
8 | 99.988125 | 8192 | 63.15 |
在这个例子中,可以看到仔细调整后,精度提升了大约 10%。 模型中不同 op 的输入输出之间存在很大差异,一组全局的 percentile 参数可能很难满足所有 op 的需求,对精度要求较高时,可以先通过上面的方法找到较好的全局参数,再通过 debug 工具找到误差较大的几个 op,单独为这几个 op 设置 percentile 参数,设置方式参照 qconfig 设置。下面列举几种常见的容易导致误差较大的数据分布:
超长尾分布,percentile 的取值应当小一些,图中 99.9 是较好的取值。
值域过大,且分布并不集中在一处,这种情况无论是保留尾部还是忽略尾部都会带来较大的精度损失,应该在训练浮点模型时通过调整 weight decay 等参数避免这种情况的出现。
layernorm 的输出分布会呈现出若干集中度非常高的区域,此时 percentile 按照正常方法调整对于量化结果不会有任何影响,需要将 percentile 调整幅度增加。
-
mse。可调整的参数只有 stride,默认 stride 为 1,会逐步尝试最大值的 100 分位并选出量化反量化前后误差最小(L2 距离)的分位对应的值。此方法对大模型耗时较高,在合理范围内调大 stride 可以在保证精度的前提下减少耗时,stride 调整过大会影响精度。注意,调整此方法的参数只能优化耗时,并不能显著提升精度。
-
kl。可调的参数一共有两个 bin 和 update_interval。由于此方法耗时过长,不建议调整默认 bin,update_interval 默认为 1,调大可以减少耗时,但需要保证 update_interval 小于总的 calibration step,否则无法得到正常的量化参数。
-
mix。此方法为混合校准,对于每一个需要统计的地方,都会尝试 percentile 方法的不同参数,选出量化反量化前后误差最小(L2 距离)的方法。自动化程度较高,没有需要调整的参数。
调参技巧
-
calibration 数据越多越好,但因为边际效应的存在,当数据量大到一定程度后,对精度的提升将非常有限。如果训练集较小,可以全部用来 calibration,如果训练集较大,可以结合 calibration 耗时挑选大小合适的子集,建议至少进行 10 - 100 个 step 的校准。
-
数据可以做水平翻转这类 augmentation,不要做马赛克这种 augmentation。尽量使用 infer 阶段的前处理 + 训练数据进行校准。
-
Batch size 尽可能大,如果数据噪声较大或模型离群点较多,可以适当减小。此参数应当在尝试 min max 方法时确定。
-
average_constant 表示每个 step 对最大值最小值的影响,average_constant 越小,当前 step 的影响越小,历史滑动均值的影响越大。该参数需要结合数据量在 0.01 ~ 0.5 之间调整。当数据量充足时(step > 100),average_constant 取 0.01,数据量不足时,average_constant 酌情增加,极端情况下,只有 2 个 step 的数据,average_constant 取 0.5。此参数应当在尝试 min max 方法时确定,之后其他方法都沿用此参数。
-
calibration 模型精度较好时,固定 feature map 的量化参数进行 QAT 训练可以取得更好的效果,精度较差时,则不能固定 calibration 得到的量化参数。关于精度是好还是坏,没有明确的标准,需要去尝试。比如:某模型精度为 100,如果 calibration 精度为 50,那么精度肯定称不 上好,但如果 calibration 精度为 95,那么这个精度是否可以达到固定 feature map 量化参数的程度就需要尝试了,通常做法是固定与不固定都做实验进行对比。
-
优先尝试 min max 方法,该方法是速度最快的,用来跑通 calibration 流程,调整并确定 batch size 和 average_constant 两个参数,接着分别尝试 percentile、kl、mse 和 mix 四种方法并选取效果最好的方法。
Observer 参数文档
class horizon_plugin_pytorch.quantization.observer_v2.KLObserver(bins: int = 512, update_interval: int = 1, averaging_constant: float = 0.01, ch_axis: int = - 1, dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8', qscheme: torch.qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)
KL observer. KL observer based on histogram. Histogram is calculated online and won’t be saved.
参数
-
bins – Number of histograms bins.
-
update_interval – Interval of computing KL entropy and update min/max. KLObserver will constantly collect histograms of activations, but only perform KL calculation when update_interval is satisfied. if it is set to 1, KL entropy will be computed every forward step. Larger interval guarantees less time and does no harm to calibration accuracy. Set it to the total calibration steps can achieve best performance. update_interval must be no greater than total calibration steps, otherwise no min/max will be computed.
-
averaging_constant – Averaging constant for min/max.
-
ch_axis – Channel axis.
-
dtype – Quantized data type.
-
qscheme – Quantization scheme to be used.
-
quant_min – Min quantization value. Will follow dtype if unspecified.
-
quant_max – Max quantization value. Will follow dtype if unspecified.
-
is_sync_quantize – If sync statistics when training with multiple devices.
-
factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.
forward(x_orig)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class horizon_plugin_pytorch.quantization.observer_v2.MSEObserver(stride: int = 1, averaging_constant: float = 0.01, ch_axis: int = - 1, dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8', qscheme: torch.qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)
MSE observer.
Observer module for computing the quantization parameters based on the Mean Square Error (MSE) between the original tensor and the quantized one.
This observer linear searches the quantization scales that minimize MSE.
参数
-
stride – Searching stride. Larger value gives smaller search space, which means less computing time but possibly poorer accuracy. Default is 1. Suggests no greater than 20.
-
averaging_constant – Averaging constant for min/max.
-
ch_axis – Channel axis.
-
dtype – Quantized data type.
-
qscheme – Quantization scheme to be used.
-
quant_min – Min quantization value. Will follow dtype if unspecified.
-
quant_max – Max quantization value. Will follow dtype if unspecified.
-
is_sync_quantize – If sync statistics when training with multiple devices.
-
factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.
forward(x_orig)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class horizon_plugin_pytorch.quantization.observer_v2.MinMaxObserver(averaging_constant: float = 0.01, ch_axis: int = - 1, dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8', qscheme: torch.qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)
Min max observer.
This observer computes the quantization parameters based on minimums and maximums of the incoming tensors. The module records the moving average minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.
参数
-
averaging_constant – Averaging constant for min/max.
-
ch_axis – Channel axis.
-
dtype – Quantized data type.
-
qscheme – Quantization scheme to be used.
-
quant_min – Min quantization value. Will follow dtype if unspecified.
-
quant_max – Max quantization value. Will follow dtype if unspecified.
-
is_sync_quantize – If sync statistics when training with multiple devices.
-
factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.
forward(x_orig)
Record the running minimum and maximum of x.
class horizon_plugin_pytorch.quantization.observer_v2.MixObserver(averaging_constant: float = 0.01, ch_axis: int = - 1, dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8', qscheme: torch.qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)
Mix observer.
This observer computes the quantization parameters based on multiple calibration methods and selects the quantization parameters with the smallest quantization error.
参数
-
averaging_constant – Averaging constant for min/max.
-
ch_axis – Channel axis.
-
dtype – Quantized data type.
-
qscheme – Quantization scheme to be used.
-
quant_min – Min quantization value. Will follow dtype if unspecified.
-
quant_max – Max quantization value. Will follow dtype if unspecified.
-
is_sync_quantize – If sync statistics when training with multiple devices.
-
factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.
forward(x_orig)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class horizon_plugin_pytorch.quantization.observer_v2.PercentileObserver(percentile: float = 99.99, bins: int = 2048, averaging_constant: float = 0.01, ch_axis: int = - 1, dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8', qscheme: torch.qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)
Percentile observer.
Percentile observer based on histogram. Histogram is calculated online and won’t be saved. The minimum and maximum are moving averaged to compute the quantization parameters.
参数
-
percentile – Index percentile of histrogram
-
bins – Number of histograms bins.
-
averaging_constant – Averaging constant for min/max.
-
ch_axis – Channel axis.
-
dtype – Quantized data type.
-
qscheme – Quantization scheme to be used.
-
quant_min – Min quantization value. Will follow dtype if unspecified.
-
quant_max – Max quantization value. Will follow dtype if unspecified.
-
is_sync_quantize – If sync statistics when training with multiple devices.
-
factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.
forward(x_orig)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class horizon_plugin_pytorch.quantization.MovingAverageMinMaxObserver(averaging_constant=0.01, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=None, quant_max=None, is_sync_quantize=False, factory_kwargs=None)
MovingAverageMinMax Observer.
Observer module for computing the quantization parameters based on the moving average of the min and max values.
This observer computes the quantization parameters based on the moving averages of minimums and maximums of the incoming tensors. The module records the average minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.
参数
-
averaging_constant – Averaging constant for min/max.
-
dtype – Quantized data type
-
qscheme – Quantization scheme to be used, only support per_tensor_symmetric scheme
-
reduce_range – Reduces the range of the quantized data type by 1 bit
-
quant_min – Minimum quantization value.
-
quant_max – Maximum quantization value.
-
is_sync_quantize – Whether use sync quantize
-
factory_kwargs – Arguments for register data buffer
forward(x_orig)
Record the running minimum and maximum of x.
class horizon_plugin_pytorch.quantization.MovingAveragePerChannelMinMaxObserver(averaging_constant=0.01, ch_axis=0, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, quant_min=None, quant_max=None, is_sync_quantize=False, factory_kwargs=None)
MovingAveragePerChannelMinMax Observer.
Observer module for computing the quantization parameters based on the running per channel min and max values.
This observer uses the tensor min/max statistics to compute the per channel quantization parameters. The module records the running minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.
参数
-
averaging_constant – Averaging constant for min/max.
-
ch_axis – Channel axis
-
dtype – Quantized data type
-
qscheme – Quantization scheme to be used, Only support per_channel_symmetric
-
quant_min – Minimum quantization value.
-
quant_max – Maximum quantization value.
-
is_sync_quantize – whether use sync quantize
-
factory_kwargs – Arguments for register data buffer
forward(x_orig)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
量化训练指南
量化训练通过在模型中插入一些伪量化节点,从而使得通过量化训练得到的模型转换成定点模型时尽可能减少精度损失。 量化训练和传统的模型训练无异,开发者可以从零开始,搭建一个伪量化模型,然后对该伪量化模型进行训练。 由于部署的硬件平台有诸多限制,对于开发者来说,搞清这些限制,并且根据这些限制搭建伪量化模型门槛较高。量化训练工具通过在开发者提供的浮点模型上根据部署平台的限制自动插入伪量化量化算子的方法,降低开发者开发量化模型的门槛。
量化训练由于施加了各种限制,因此,一般来说,量化训练比纯浮点模型的训练更加困难。量化训练工具的目标是降低量化训练的难度,降低量化模型部署的工程难度。
流程和示例
虽然量化训练工具不强制要求用户从一个预训练的浮点模型开始,但是,经验表明,通常从预训练的高精度浮点模型开始量化训练能大大降低量化训练的难度。
from horizon_plugin_pytorch.quantization import get_default_qconfig
# 将模型转为 QAT 状态
default_qat_8bit_fake_quant_qconfig = get_default_qconfig(
activation_fake_quant="fake_quant",
weight_fake_quant="fake_quant",
activation_observer="min_max",
weight_observer="min_max",
activation_qkwargs=None,
weight_qkwargs={
"qscheme": torch.per_channel_symmetric,
"ch_axis": 0,
},
)
default_qat_out_8bit_fake_quant_qconfig = get_default_qconfig(
activation_fake_quant=None,
weight_fake_quant="fake_quant",
activation_observer=None,
weight_observer="min_max",
activation_qkwargs=None,
weight_qkwargs={
"qscheme": torch.per_channel_symmetric,
"ch_axis": 0,
},
)
qat_model = prepare_qat_fx(
float_model,
{
"": default_qat_8bit_fake_quant_qconfig,
"module_name": {
"classifier": default_qat_out_8bit_fake_quant_qconfig,
},
},
).to(device)
# 加载 Calibration 模型中的量化参数
qat_model.load_state_dict(calib_model.state_dict())
# 进行量化感知训练
# 作为一个 filetune 过程,量化感知训练一般需要设定较小的学习率
optimizer = torch.optim.SGD(
qat_model.parameters(), lr=0.0001, weight_decay=2e-4
)
for nepoch in range(epoch_num):
# 注意此处对 QAT 模型 training 状态的控制方法
qat_model.train()
set_fake_quantize(qat_model, FakeQuantState.QAT)
train_one_epoch(
qat_model,
nn.CrossEntropyLoss(),
optimizer,
None,
train_data_loader,
device,
)
# 注意此处对 QAT 模型 eval 状态的控制方法
qat_model.eval()
set_fake_quantize(qat_model, FakeQuantState.VALIDATION)
# 测试 qat 模型精度
top1, top5 = evaluate(
qat_model,
eval_data_loader,
device,
)
print(
"QAT model: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
top1.avg, top5.avg
)
)
# 测试 quantized 模型精度
quantized_model = convert_fx(qat_model.eval()).to(device)
top1, top5 = evaluate(
quantized_model,
eval_data_loader,
device,
)
print(
"Quantized model: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
top1.avg, top5.avg
)
)
由于部署平台的底层限制,QAT 模型无法完全代表最终上板精度,请务必监控 quantized 模型精度,确保 quantized 模型精度正常,否则可能出现模型上板掉点问题。
由上述示例代码可以看到,与传统的纯浮点模型训练相比,量化训练多了两个步骤:
- prepare_qat_fx
- 加载 Calibration 模型参数
prepare_qat_fx
这一步骤的目标是对浮点网络进行变换,插入伪量化节点。
加载 Calibration 模型参数
通过加载 Calibration 得到的伪量化参数,来获得一个较好的初始化。
训练迭代
至此,完成了伪量化模型的搭建和参数的初始化,然后就可以进行常规的训练迭代和模型参数更新,并且监控 quantized 模型精度。