# awq-py/awq/apply_awq.py
"""
Implements AWQ (Activation-aware Weight Quantization) for llama.cpp use cases.
Original paper: https://arxiv.org/abs/2306.00978

This code is based on versions of the AWQ implementation found in the following repositories:
* https://github.com/mit-han-lab/llm-awq
* https://github.com/casper-hansen/AutoAWQ
"""

import os
import torch
import torch.nn as nn

from transformers import AutoModelForCausalLM, AutoConfig
from transformers.models.bloom.modeling_bloom import BloomGelu
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.activations import GELUActivation


class ScaledActivation(nn.Module):
    """
    ScaledActivation module wraps an existing activation function and applies a
    scale factor to its output.

    Args:
        module (nn.Module): The activation function to be scaled.
        scales (torch.Tensor): A tensor of size (num_features,) containing the initial
            scale factors for each feature.

    Returns:
        torch.Tensor: The scaled output of the activation function.
    """

    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)

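# Illustrative usage of ScaledActivation (a sketch, not part of the original
# module; `hidden_dim` and `x` are placeholders):
#
#   act = ScaledActivation(nn.GELU(), torch.ones(hidden_dim))
#   y = act(x)  # equals nn.GELU()(x) divided feature-wise by the learned scales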

def set_op_by_name(layer, name, new_module):
    """
    Replaces the submodule at the given dot-separated name with a new module.

    Args:
        layer (nn.Module): The layer in which to replace the submodule.
        name (str): The path to the submodule to be replaced, using dot notation
            to access nested modules.
        new_module (nn.Module): The new module to replace the existing one.
    """
    levels = name.split(".")
    if len(levels) > 1:
        mod_ = layer
        for l_idx in range(len(levels) - 1):
            if levels[l_idx].isdigit():
                mod_ = mod_[int(levels[l_idx])]
            else:
                mod_ = getattr(mod_, levels[l_idx])
        setattr(mod_, levels[-1], new_module)
    else:
        setattr(layer, name, new_module)

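# Illustrative usage of set_op_by_name (a sketch; the layer path is hypothetical):
#
#   set_op_by_name(model, "model.layers.0.mlp.act_fn",
#                  ScaledActivation(old_act, scales))
#   # numeric path components ("0") index into ModuleList containers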

def get_op_by_name(module, op_name):
    """
    Retrieves a submodule within a given layer based on its name.

    Args:
        module (nn.Module): The layer containing the submodule to find.
        op_name (str): The name of the submodule.

    Returns:
        nn.Module: The requested submodule found within the given layer.

    Raises:
        ValueError: If the specified submodule cannot be found within the layer.
    """
    for name, m in module.named_modules():
        if name == op_name:
            return m
    raise ValueError(f"Cannot find op {op_name} in module {module}")

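# Illustrative usage of get_op_by_name (a sketch; the submodule name is hypothetical):
#
#   q_proj = get_op_by_name(model, "model.layers.0.self_attn.q_proj")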

@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    """
    Scales the weights of a LayerNorm and a list of fully-connected layers proportionally.

    Args:
        ln (nn.LayerNorm): The LayerNorm module to be scaled.
        fcs (List[nn.Linear]): A list of fully-connected layers to be scaled.
        scales (torch.Tensor): A 1D tensor of size (num_features,).
    """

    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    ln.weight.div_(scales)
    if hasattr(ln, "bias") and ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0

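# Illustrative usage of scale_ln_fcs (a sketch; the attribute names follow the
# Hugging Face LLaMA layout and are assumptions): the norm weight is divided by
# the AWQ scales while the projections it feeds are multiplied by them, so the
# block's output is unchanged.
#
#   layer = model.model.layers[0]
#   scale_ln_fcs(layer.input_layernorm,
#                [layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj],
#                scales)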

@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    """
    Scales the weights of two consecutive fully-connected layers: the trailing rows
    of the first are divided by the scales and the columns of the second are multiplied by them.

    Args:
        fc1 (nn.Linear): The first fully-connected layer to be scaled.
        fc2 (nn.Linear): The second fully-connected layer to be scaled.
        scales (torch.Tensor): A 1D tensor of size (num_features,).
    """
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)

    scales = scales.to(fc1.weight.device)

    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0

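# Illustrative usage of scale_fc_fc (a sketch; the attribute names are assumptions),
# e.g. for a v_proj -> o_proj pair inside an attention block:
#
#   scale_fc_fc(layer.self_attn.v_proj, layer.self_attn.o_proj, scales)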

@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    """
    Scales the weight of the fully-connected layer that follows a GELU activation
    (the GELU itself is wrapped in ScaledActivation by the caller).

    Args:
        gelu (Union[nn.GELU, BloomGelu, GELUActivation]): The GELU activation module to be scaled.
        fc (nn.Linear): The fully-connected layer to be scaled.
        scales (torch.Tensor): A 1D tensor of size (num_features,).

    Raises:
        AssertionError: If the `gelu` module is not of type `nn.GELU`, `BloomGelu`, or `GELUActivation`.
        AssertionError: If the `fc` module is not of type `nn.Linear`.
    """
    assert isinstance(gelu, (nn.GELU, BloomGelu, GELUActivation))
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0

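# Illustrative usage of scale_gelu_fc (a sketch; `mlp.act` and `mlp.fc_out` are
# hypothetical names): only the following linear layer is touched here; the GELU
# itself is wrapped in ScaledActivation by apply_scale().
#
#   scale_gelu_fc(mlp.act, mlp.fc_out, scales)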

def apply_scale(module, scales_list, input_feat_dict=None):
    """
    Applies different scaling strategies to layers based on their type and hierarchy within a given module.

    Args:
        module (nn.Module): The module containing the layers to be scaled.
        scales_list (List[Tuple[str, List[str], torch.Tensor]]): A list of tuples containing:
            * prev_op_name (str): The name of the preceding operation or module,
                relative to which the layers to be scaled are located.
            * layer_names (List[str]): A list of names of the layers to be scaled, relative to the preceding operation.
            * scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
        input_feat_dict (Optional[Dict[str, torch.Tensor]]): A dictionary mapping layer names to their corresponding
            input features (optional).
    """
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        prev_op.cuda()
        for layer in layers:
            layer.cuda()
        scales = scales.cuda()  # Tensor.cuda() is not in-place; rebind so the scales actually move

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)) or "rmsnorm" in str(prev_op.__class__).lower():
            scale_ln_fcs(prev_op, layers, scales)
        elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()
        scales = scales.cpu()

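# Illustrative usage of apply_scale (a sketch; the names and `hidden_dim` are
# assumptions): each entry pairs a preceding op with the layers it feeds and a
# per-feature scale tensor, matching the format consumed from
# awq_results["scale"] in add_scale_weights().
#
#   scales_list = [
#       ("model.layers.0.input_layernorm",
#        ["model.layers.0.self_attn.q_proj",
#         "model.layers.0.self_attn.k_proj",
#         "model.layers.0.self_attn.v_proj"],
#        torch.ones(hidden_dim)),
#   ]
#   apply_scale(model, scales_list)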

@torch.no_grad()
def apply_clip(module, clip_list):
    """
    Applies element-wise clipping to the weight of a specific layer within a given module.

    Args:
        module (nn.Module): The module containing the layer to be clipped.
        clip_list (List[Tuple[str, torch.Tensor]]): A list of tuples containing:
            * name (str): The name of the layer to be clipped, relative to the root of the module.
            * max_val (torch.Tensor): A 1D or 2D tensor defining the upper bound for each element of the layer's weight.
    """
    for name, max_val in clip_list:
        layer = get_op_by_name(module, name)
        layer.cuda()
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
        layer.cpu()

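# Illustrative usage of apply_clip (a sketch; the layer name is hypothetical and
# `max_val` comes from the precomputed AWQ results, i.e. awq_results["clip"]):
#
#   clip_list = [("model.layers.0.mlp.down_proj", max_val)]
#   apply_clip(model, clip_list)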

def add_scale_weights(model_path, scale_path, tmp_path):
    """
    Adds pre-computed Activation-aware Weight Quantization (AWQ) results to a model,
    including scaling factors and clipping bounds.

    Args:
        model_path (str): Path to the pre-trained model to be equipped with AWQ.
        scale_path (str): Path to the AWQ scale factors (.pt file).
        tmp_path (str): Path to the temporary directory where the equipped model will be saved.
    """
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, config=config, trust_remote_code=True
    )
    model.eval()
    awq_results = torch.load(str(scale_path), map_location="cpu")
    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])
    model.save_pretrained(str(tmp_path))
    # copy the tokenizer files alongside the scaled model
    os.system(f"cp {str(model_path)}/tokenizer* {str(tmp_path)}")
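
# Illustrative usage of add_scale_weights (a sketch; all paths are hypothetical):
#
#   add_scale_weights("models/llama-7b",
#                     "awq_cache/llama-7b-w4-g128.pt",
#                     "models/llama-7b-awq")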