KnutJaegersberg committed
Commit b3c0032
1 Parent(s): ac691e9

Upload 91 files

Files changed (44)
  1. quip-sharp/README.md +3 -4
  2. quip-sharp/docs/index.html +51 -51
  3. quip-sharp/docs/index.md +8 -8
  4. quip-sharp/hfize_llama.py +27 -78
  5. quip-sharp/lib/__pycache__/__init__.cpython-310.pyc +0 -0
  6. quip-sharp/lib/codebook/__pycache__/__init__.cpython-310.pyc +0 -0
  7. quip-sharp/lib/codebook/__pycache__/half_integer_4bit_1col.cpython-310.pyc +0 -0
  8. quip-sharp/lib/codebook/__pycache__/latticed4.cpython-310.pyc +0 -0
  9. quip-sharp/lib/codebook/__pycache__/latticee8_padded12.cpython-310.pyc +0 -0
  10. quip-sharp/lib/codebook/latticee8_padded12.py +100 -129
  11. quip-sharp/lib/linear/__pycache__/__init__.cpython-310.pyc +0 -0
  12. quip-sharp/lib/linear/__pycache__/fused_quantized_linear.cpython-310.pyc +0 -0
  13. quip-sharp/lib/linear/__pycache__/quantized_linear.cpython-310.pyc +0 -0
  14. quip-sharp/lib/linear/fused_quantized_linear.py +22 -0
  15. quip-sharp/lib/linear/quantized_linear.py +25 -16
  16. quip-sharp/lib/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  17. quip-sharp/lib/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
  18. quip-sharp/lib/utils/__pycache__/lm_eval_adaptor.cpython-310.pyc +0 -0
  19. quip-sharp/lib/utils/__pycache__/math_utils.cpython-310.pyc +0 -0
  20. quip-sharp/lib/utils/__pycache__/matmul_had.cpython-310.pyc +0 -0
  21. quip-sharp/lib/utils/__pycache__/matmul_kron.cpython-310.pyc +0 -0
  22. quip-sharp/lib/utils/__pycache__/misc.cpython-310.pyc +0 -0
  23. quip-sharp/lib/utils/__pycache__/unsafe_import.cpython-310.pyc +0 -0
  24. quip-sharp/lib/utils/data_utils.py +1 -1
  25. quip-sharp/lib/utils/unsafe_import.py +1 -3
  26. quip-sharp/model/__pycache__/graph_wrapper.cpython-310.pyc +0 -0
  27. quip-sharp/model/__pycache__/llama.cpython-310.pyc +0 -0
  28. quip-sharp/model/__pycache__/mistral.cpython-310.pyc +0 -0
  29. quip-sharp/model/__pycache__/version.cpython-310.pyc +0 -0
  30. quip-sharp/model/llama.py +25 -55
  31. quip-sharp/model/mistral.py +23 -54
  32. quip-sharp/model/version.py +2 -2
  33. quip-sharp/quantize_llama.py +3 -3
  34. quip-sharp/quiptools/build/lib.linux-x86_64-cpython-310/quiptools_cuda.cpython-310-x86_64-linux-gnu.so +2 -2
  35. quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/.ninja_deps +0 -0
  36. quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/.ninja_log +3 -5
  37. quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools.o +1 -1
  38. quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools_e8p_gemv.o +2 -2
  39. quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools_wrapper.o +2 -2
  40. quip-sharp/quiptools/dist/quiptools_cuda-0.0.0-py3.10-linux-x86_64.egg +2 -2
  41. quip-sharp/quiptools/quiptools_cuda.egg-info/SOURCES.txt +0 -5
  42. quip-sharp/quiptools/quiptools_e8p_gemv.cu +501 -227
  43. quip-sharp/quiptools/quiptools_wrapper.cpp +8 -3
  44. quip-sharp/scripts/upload_hf.py +1 -0
quip-sharp/README.md CHANGED
@@ -10,7 +10,7 @@ We also provide a full codebase that allows users to quantize and deploy their o
10
  | OPTQ | 3 bit | 4.577 | 6.838 | 0.544 | **0.786** |
11
  | OPTQ | 2 bit | 109.820 | 62.692 | 0.253 | 0.505 |
12
  | QuIP | 2 bit | 5.574 | 8.268 | 0.544 | 0.751 |
13
- | **QuIP#** | **2 bit** | **4.156** | **6.545** | **0.595** | 0.785 |
14
 
15
  Quantization results on Llama 2 70B. QuIP# achieves near-native performance at 2 bits, outperforming all other presented baselines.
16
 
@@ -18,9 +18,8 @@ Quantization results on Llama 2 70B. QuIP# achieves near-native performance at 2
18
 
19
  ## News
20
 
21
- - We have "deprecated" the 2 bit D4 quantized models as they perform worse than 2 bit E8P models and are slower to run. The code to quantize and run D4 models is still in the codebase, but the D4 models have been removed from HF and we are no longer actively supporting them.
22
- - We recently added 2 and 4 bit quantized versions of [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) and [OpenHermes 2.5](https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B). See the Model Zoo section for more details.
23
- - **The 4 bit models have been replaced by new bit-packed models that end with the `-Packed` suffix. The old models have been deprecated, removed, and do not work with the current code (and vice versa). Make sure to pull the latest code to run the 4 bit models.**
24
 
25
  ## Installation
26
 
 
10
  | OPTQ | 3 bit | 4.577 | 6.838 | 0.544 | **0.786** |
11
  | OPTQ | 2 bit | 109.820 | 62.692 | 0.253 | 0.505 |
12
  | QuIP | 2 bit | 5.574 | 8.268 | 0.544 | 0.751 |
13
+ | **QuIP#** | **2 bit** | **4.159** | **6.529** | **0.595** | **0.786** |
14
 
15
  Quantization results on Llama 2 70B. QuIP# achieves near-native performance at 2 bits, outperforming all other presented baselines.
16
 
 
18
 
19
  ## News
20
 
21
+ - We merged in a faster E8P kernel that (with CUDA graphs) is around twice as fast as before. Make sure to pull the latest code and models and recompile `quiptools` to get the faster kernel. As a reminder, `hf.generate()` does not work with CUDA graphs, so the generation speed in `interactive_gen.py` is not representative of the kernel's actual speed.
22
+ - We fixed a duplicated entry in the E8P codebook and updated the result tables.
 
23
 
24
  ## Installation
25
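
The News entry above ties the roughly 2x E8P speedup to CUDA graph capture, which replays a recorded kernel sequence against fixed-shape, fixed-address tensors; that is also why the dynamically growing `hf.generate()` loop cannot be captured and `interactive_gen.py` understates the kernel's speed. Below is a minimal, hedged PyTorch sketch of the general capture-and-replay pattern; it is not code from this commit, and `model` and the buffer shapes are made-up placeholders.

```python
import torch

# Hypothetical fixed-shape decode step; a real model would use static
# KV-cache buffers sized for the maximum sequence length.
model = torch.nn.Linear(4096, 4096).cuda().half()
static_input = torch.zeros(1, 4096, device='cuda', dtype=torch.half)

# Warm up on a side stream before capture, as the CUDA graphs API requires.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(side)

# Capture one forward pass; every tensor inside must keep its shape/address.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Per decoded token: copy new activations into the static buffer and replay.
static_input.copy_(torch.randn_like(static_input))
graph.replay()  # static_output now holds the new result
```

Because `hf.generate()` reallocates its inputs and grows the KV cache on every step, it breaks the fixed-shape assumption, which matches the caveat in the News entry.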
 
quip-sharp/docs/index.html CHANGED
@@ -283,10 +283,10 @@ class="math inline">\(\uparrow\)</span></th>
283
  <tr class="odd">
284
  <td style="text-align: center;"><strong>QuIP#</strong></td>
285
  <td style="text-align: center;"><strong>2 bit</strong></td>
286
- <td style="text-align: center;"><strong>4.156</strong></td>
287
- <td style="text-align: center;"><strong>6.545</strong></td>
288
  <td style="text-align: center;"><strong>0.595</strong></td>
289
- <td style="text-align: center;">0.785</td>
290
  </tr>
291
  </tbody>
292
  </table>
@@ -688,13 +688,13 @@ class="math inline">\(\uparrow\)</span></th>
688
  <tr class="even">
689
  <td style="text-align: center;">2-70B</td>
690
  <td style="text-align: center;">QuIP#</td>
691
- <td style="text-align: center;">6.535</td>
692
- <td style="text-align: center;">4.156</td>
693
- <td style="text-align: center;">0.469</td>
694
  <td style="text-align: center;">0.595</td>
695
- <td style="text-align: center;">0.795</td>
696
- <td style="text-align: center;">0.785</td>
697
- <td style="text-align: center;">0.740</td>
698
  </tr>
699
  <tr class="odd">
700
  <td style="text-align: center;">2-13B</td>
@@ -710,13 +710,13 @@ class="math inline">\(\uparrow\)</span></th>
710
  <tr class="even">
711
  <td style="text-align: center;">2-13B</td>
712
  <td style="text-align: center;">QuIP#</td>
713
- <td style="text-align: center;">8.769</td>
714
- <td style="text-align: center;">6.003</td>
715
- <td style="text-align: center;">0.381</td>
716
- <td style="text-align: center;">0.502</td>
717
- <td style="text-align: center;">0.643</td>
718
- <td style="text-align: center;">0.751</td>
719
- <td style="text-align: center;">0.637</td>
720
  </tr>
721
  <tr class="odd">
722
  <td style="text-align: center;">2-7B</td>
@@ -732,13 +732,13 @@ class="math inline">\(\uparrow\)</span></th>
732
  <tr class="even">
733
  <td style="text-align: center;">2-7B</td>
734
  <td style="text-align: center;">QuIP#</td>
735
- <td style="text-align: center;">12.208</td>
736
- <td style="text-align: center;">8.201</td>
737
- <td style="text-align: center;">0.346</td>
738
- <td style="text-align: center;">0.454</td>
739
- <td style="text-align: center;">0.647</td>
740
- <td style="text-align: center;">0.726</td>
741
- <td style="text-align: center;">0.618</td>
742
  </tr>
743
  <tr class="odd">
744
  <td style="text-align: center;">1-65b</td>
@@ -754,13 +754,13 @@ class="math inline">\(\uparrow\)</span></th>
754
  <tr class="even">
755
  <td style="text-align: center;">1-65b</td>
756
  <td style="text-align: center;">QuIP#</td>
757
- <td style="text-align: center;">6.749</td>
758
- <td style="text-align: center;">4.573</td>
759
- <td style="text-align: center;">0.435</td>
760
- <td style="text-align: center;">0.566</td>
761
- <td style="text-align: center;">0.831</td>
762
- <td style="text-align: center;">0.792</td>
763
- <td style="text-align: center;">0.756</td>
764
  </tr>
765
  <tr class="odd">
766
  <td style="text-align: center;">1-30B</td>
@@ -776,13 +776,13 @@ class="math inline">\(\uparrow\)</span></th>
776
  <tr class="even">
777
  <td style="text-align: center;">1-30B</td>
778
  <td style="text-align: center;">QuIP#</td>
779
- <td style="text-align: center;">7.465</td>
780
- <td style="text-align: center;">5.311</td>
781
- <td style="text-align: center;">0.422</td>
782
- <td style="text-align: center;">0.537</td>
783
- <td style="text-align: center;">0.659</td>
784
- <td style="text-align: center;">0.776</td>
785
- <td style="text-align: center;">0.714</td>
786
  </tr>
787
  <tr class="odd">
788
  <td style="text-align: center;">1-13B</td>
@@ -798,13 +798,13 @@ class="math inline">\(\uparrow\)</span></th>
798
  <tr class="even">
799
  <td style="text-align: center;">1-13B</td>
800
  <td style="text-align: center;">QuIP#</td>
801
- <td style="text-align: center;">8.426</td>
802
- <td style="text-align: center;">6.353</td>
803
- <td style="text-align: center;">0.382</td>
804
- <td style="text-align: center;">0.537</td>
805
- <td style="text-align: center;">0.665</td>
806
- <td style="text-align: center;">0.757</td>
807
- <td style="text-align: center;">0.687</td>
808
  </tr>
809
  <tr class="odd">
810
  <td style="text-align: center;">1-7B</td>
@@ -820,13 +820,13 @@ class="math inline">\(\uparrow\)</span></th>
820
  <tr class="even">
821
  <td style="text-align: center;">1-7B</td>
822
  <td style="text-align: center;">QuIP#</td>
823
- <td style="text-align: center;">10.927</td>
824
- <td style="text-align: center;">8.146</td>
825
- <td style="text-align: center;">0.347</td>
826
- <td style="text-align: center;">0.471</td>
827
- <td style="text-align: center;">0.673</td>
828
- <td style="text-align: center;">0.724</td>
829
- <td style="text-align: center;">0.621</td>
830
  </tr>
831
  </tbody>
832
  </table>
 
283
  <tr class="odd">
284
  <td style="text-align: center;"><strong>QuIP#</strong></td>
285
  <td style="text-align: center;"><strong>2 bit</strong></td>
286
+ <td style="text-align: center;"><strong>4.159</strong></td>
287
+ <td style="text-align: center;"><strong>6.529</strong></td>
288
  <td style="text-align: center;"><strong>0.595</strong></td>
289
+ <td style="text-align: center;">0.786</td>
290
  </tr>
291
  </tbody>
292
  </table>
 
688
  <tr class="even">
689
  <td style="text-align: center;">2-70B</td>
690
  <td style="text-align: center;">QuIP#</td>
691
+ <td style="text-align: center;">6.529</td>
692
+ <td style="text-align: center;">4.158</td>
693
+ <td style="text-align: center;">0.472</td>
694
  <td style="text-align: center;">0.595</td>
695
+ <td style="text-align: center;">0.791</td>
696
+ <td style="text-align: center;">0.786</td>
697
+ <td style="text-align: center;">0.742</td>
698
  </tr>
699
  <tr class="odd">
700
  <td style="text-align: center;">2-13B</td>
 
710
  <tr class="even">
711
  <td style="text-align: center;">2-13B</td>
712
  <td style="text-align: center;">QuIP#</td>
713
+ <td style="text-align: center;">8.755</td>
714
+ <td style="text-align: center;">6.058</td>
715
+ <td style="text-align: center;">0.371</td>
716
+ <td style="text-align: center;">0.501</td>
717
+ <td style="text-align: center;">0.665</td>
718
+ <td style="text-align: center;">0.757</td>
719
+ <td style="text-align: center;">0.636</td>
720
  </tr>
721
  <tr class="odd">
722
  <td style="text-align: center;">2-7B</td>
 
732
  <tr class="even">
733
  <td style="text-align: center;">2-7B</td>
734
  <td style="text-align: center;">QuIP#</td>
735
+ <td style="text-align: center;">12.062</td>
736
+ <td style="text-align: center;">8.224</td>
737
+ <td style="text-align: center;">0.325</td>
738
+ <td style="text-align: center;">0.428</td>
739
+ <td style="text-align: center;">0.623</td>
740
+ <td style="text-align: center;">0.712</td>
741
+ <td style="text-align: center;">0.624</td>
742
  </tr>
743
  <tr class="odd">
744
  <td style="text-align: center;">1-65b</td>
 
754
  <tr class="even">
755
  <td style="text-align: center;">1-65b</td>
756
  <td style="text-align: center;">QuIP#</td>
757
+ <td style="text-align: center;">6.744</td>
758
+ <td style="text-align: center;">4.566</td>
759
+ <td style="text-align: center;">0.436</td>
760
+ <td style="text-align: center;">0.569</td>
761
+ <td style="text-align: center;">0.817</td>
762
+ <td style="text-align: center;">0.805</td>
763
+ <td style="text-align: center;">0.736</td>
764
  </tr>
765
  <tr class="odd">
766
  <td style="text-align: center;">1-30B</td>
 
776
  <tr class="even">
777
  <td style="text-align: center;">1-30B</td>
778
  <td style="text-align: center;">QuIP#</td>
779
+ <td style="text-align: center;">7.471</td>
780
+ <td style="text-align: center;">5.317</td>
781
+ <td style="text-align: center;">0.429</td>
782
+ <td style="text-align: center;">0.545</td>
783
+ <td style="text-align: center;">0.669</td>
784
+ <td style="text-align: center;">0.779</td>
785
+ <td style="text-align: center;">0.718</td>
786
  </tr>
787
  <tr class="odd">
788
  <td style="text-align: center;">1-13B</td>
 
798
  <tr class="even">
799
  <td style="text-align: center;">1-13B</td>
800
  <td style="text-align: center;">QuIP#</td>
801
+ <td style="text-align: center;">8.425</td>
802
+ <td style="text-align: center;">6.381</td>
803
+ <td style="text-align: center;">0.387</td>
804
+ <td style="text-align: center;">0.536</td>
805
+ <td style="text-align: center;">0.647</td>
806
+ <td style="text-align: center;">0.750</td>
807
+ <td style="text-align: center;">0.669</td>
808
  </tr>
809
  <tr class="odd">
810
  <td style="text-align: center;">1-7B</td>
 
820
  <tr class="even">
821
  <td style="text-align: center;">1-7B</td>
822
  <td style="text-align: center;">QuIP#</td>
823
+ <td style="text-align: center;">10.970</td>
824
+ <td style="text-align: center;">8.286</td>
825
+ <td style="text-align: center;">0.352</td>
826
+ <td style="text-align: center;">0.464</td>
827
+ <td style="text-align: center;">0.647</td>
828
+ <td style="text-align: center;">0.720</td>
829
+ <td style="text-align: center;">0.624</td>
830
  </tr>
831
  </tbody>
832
  </table>
quip-sharp/docs/index.md CHANGED
@@ -52,7 +52,7 @@ These two methods allow QuIP# to significantly close the gap between 2 bit quant
52
  | OPTQ | 3 bit | 4.577 | 6.838 | 0.544 | **0.786** |
53
  | OPTQ | 2 bit | 109.820 | 62.692 | 0.253 | 0.505 |
54
  | QuIP | 2 bit | 5.574 | 8.268 | 0.544 | 0.751 |
55
- | **QuIP#** | **2 bit** | **4.156** | **6.545** | **0.595** | 0.785 |
56
 
57
  :Quantization results on Llama 2 70B. QuIP# achieves near-native performance at 2 bits, outperforming all other presented baselines.
58
 
@@ -237,18 +237,18 @@ Additional results are available [here](https://docs.google.com/spreadsheets/d/1
237
  | Model | Method | C4 $\downarrow$ | Wiki $\downarrow$ | ArcC $\uparrow$ | ArcE $\uparrow$ | BoolQ $\uparrow$ | PiQA $\uparrow$ | WinoGrande $\uparrow$ |
238
  |:---------:|:---------:|:---------------:|:-----------------:|:---------------:|:---------------:|:-------------------:|:---------------:|:-------------------------------:|
239
  | 2-70B | fp16 | 5.533 | 3.120 | 0.480 | 0.597 | 0.766 | 0.809 | 0.768 |
240
- | 2-70B | QuIP# | 6.535 | 4.156 | 0.469 | 0.595 | 0.795 | 0.785 | 0.740 |
241
  | 2-13B | fp16 | 6.520 | 4.574 | 0.443 | 0.580 | 0.690 | 0.790 | 0.699 |
242
- | 2-13B | QuIP# | 8.769 | 6.003 | 0.381 | 0.502 | 0.643 | 0.751 | 0.637 |
243
  | 2-7B | fp16 | 7.036 | 5.116 | 0.406 | 0.535 | 0.710 | 0.769 | 0.670 |
244
- | 2-7B | QuIP# | 12.208 | 8.201 | 0.346 | 0.454 | 0.647 | 0.726 | 0.618 |
245
  | 1-65b | fp16 | 5.811 | 3.532 | 0.463 | 0.588 | 0.823 | 0.809 | 0.771 |
246
- | 1-65b | QuIP# | 6.749 | 4.573 | 0.435 | 0.566 | 0.831 | 0.792 | 0.756 |
247
  | 1-30B | fp16 | 6.130 | 4.101 | 0.453 | 0.590 | 0.684 | 0.801 | 0.728 |
248
- | 1-30B | QuIP# | 7.465 | 5.311 | 0.422 | 0.537 | 0.659 | 0.776 | 0.714 |
249
  | 1-13B | fp16 | 6.798 | 5.091 | 0.444 | 0.599 | 0.684 | 0.792 | 0.701 |
250
- | 1-13B | QuIP# | 8.426 | 6.353 | 0.382 | 0.537 | 0.665 | 0.757 | 0.687 |
251
  | 1-7B | fp16 | 7.343 | 5.677 | 0.415 | 0.525 | 0.731 | 0.774 | 0.670 |
252
- | 1-7B | QuIP# | 10.927 | 8.146 | 0.347 | 0.471 | 0.673 | 0.724 | 0.621 |
253
  :QuIP# results across all Llama 1 and 2 models. QuIP# achieves near-native performance at 2 bits on language modeling (C4, Wiki) and zero shot (ArcC, ArcE, BoolQ, PiQA, WinoGrande) tasks.
254
  </div>
 
52
  | OPTQ | 3 bit | 4.577 | 6.838 | 0.544 | **0.786** |
53
  | OPTQ | 2 bit | 109.820 | 62.692 | 0.253 | 0.505 |
54
  | QuIP | 2 bit | 5.574 | 8.268 | 0.544 | 0.751 |
55
+ | **QuIP#** | **2 bit** | **4.159** | **6.529** | **0.595** | 0.786 |
56
 
57
  :Quantization results on Llama 2 70B. QuIP# achieves near-native performance at 2 bits, outperforming all other presented baselines.
58
 
 
237
  | Model | Method | C4 $\downarrow$ | Wiki $\downarrow$ | ArcC $\uparrow$ | ArcE $\uparrow$ | BoolQ $\uparrow$ | PiQA $\uparrow$ | WinoGrande $\uparrow$ |
238
  |:---------:|:---------:|:---------------:|:-----------------:|:---------------:|:---------------:|:-------------------:|:---------------:|:-------------------------------:|
239
  | 2-70B | fp16 | 5.533 | 3.120 | 0.480 | 0.597 | 0.766 | 0.809 | 0.768 |
240
+ | 2-70B | QuIP# | 6.529 | 4.158 | 0.472 | 0.595 | 0.791 | 0.786 | 0.742 |
241
  | 2-13B | fp16 | 6.520 | 4.574 | 0.443 | 0.580 | 0.690 | 0.790 | 0.699 |
242
+ | 2-13B | QuIP# | 8.755 | 6.058 | 0.371 | 0.501 | 0.665 | 0.757 | 0.636 |
243
  | 2-7B | fp16 | 7.036 | 5.116 | 0.406 | 0.535 | 0.710 | 0.769 | 0.670 |
244
+ | 2-7B | QuIP# | 12.062 | 8.224 | 0.325 | 0.428 | 0.623 | 0.712 | 0.624 |
245
  | 1-65b | fp16 | 5.811 | 3.532 | 0.463 | 0.588 | 0.823 | 0.809 | 0.771 |
246
+ | 1-65b | QuIP# | 6.744 | 4.566 | 0.436 | 0.569 | 0.817 | 0.805 | 0.736 |
247
  | 1-30B | fp16 | 6.130 | 4.101 | 0.453 | 0.590 | 0.684 | 0.801 | 0.728 |
248
+ | 1-30B | QuIP# | 7.471 | 5.317 | 0.429 | 0.545 | 0.669 | 0.779 | 0.718 |
249
  | 1-13B | fp16 | 6.798 | 5.091 | 0.444 | 0.599 | 0.684 | 0.792 | 0.701 |
250
+ | 1-13B | QuIP# | 8.425 | 6.381 | 0.387 | 0.536 | 0.647 | 0.750 | 0.669 |
251
  | 1-7B | fp16 | 7.343 | 5.677 | 0.415 | 0.525 | 0.731 | 0.774 | 0.670 |
252
+ | 1-7B | QuIP# | 10.970 | 8.286 | 0.352 | 0.464 | 0.647 | 0.720 | 0.624 |
253
  :QuIP# results across all Llama 1 and 2 models. QuIP# achieves near-native performance at 2 bits on language modeling (C4, Wiki) and zero shot (ArcC, ArcE, BoolQ, PiQA, WinoGrande) tasks.
254
  </div>
quip-sharp/hfize_llama.py CHANGED
@@ -5,7 +5,6 @@ import torch
5
  from transformers import AutoTokenizer
6
  from model.version import MODEL_VERSION
7
  from model.llama import LlamaForCausalLM as llama_fuse
8
- from model.llama_nofuse import LlamaForCausalLM as llama_nofuse
9
  from model.mistral import MistralForCausalLM
10
  from lib import codebook
11
  from lib.utils.unsafe_import import model_from_hf_path
@@ -32,7 +31,6 @@ def unpack_quip(module, saved_layer, codebook_id, codesz):
32
  module.B.copy_(saved_layer['B'])
33
  module.SU.copy_(saved_layer['SU'])
34
  module.SV.copy_(saved_layer['SV'])
35
- module.Wscale.copy_(saved_layer['Wscale'])
36
  if module.rescale_WH:
37
  module.scaleWH.copy_(saved_layer['scaleWH'])
38
 
@@ -50,11 +48,10 @@ def main(args):
50
  tokenizer = AutoTokenizer.from_pretrained(model_config._name_or_path)
51
 
52
  model_type = model_config.model_type
53
- fused = model_config.quip_params.get('fused', True)
54
  model_config.quip_params['model_version'] = MODEL_VERSION
55
 
56
  if model_type == 'llama':
57
- model_cls = llama_fuse if fused else llama_nofuse
58
  elif model_type == 'mistral':
59
  model_cls = MistralForCausalLM
60
  else:
@@ -71,80 +68,32 @@ def main(args):
71
  layer = model.model.layers[ii]
72
  cpu = torch.device('cpu')
73
 
74
- if fused:
75
- glog.info(f'loading layer {ii} qkv')
76
- saved_layer = torch.load(f'{args.quantized_path}/{ii}_qkv.pt', map_location=cpu)
77
- layer.self_attn.q_scale.copy_(saved_layer['W_q_scale'])
78
- layer.self_attn.k_scale.copy_(saved_layer['W_k_scale'])
79
- layer.self_attn.v_scale.copy_(saved_layer['W_v_scale'])
80
- unpack_quip(layer.self_attn.qkv_proj, saved_layer, codebook_id, codesz)
81
-
82
- glog.info(f'loading layer {ii} up')
83
- saved_layer = torch.load(f'{args.quantized_path}/{ii}_up.pt', map_location=cpu)
84
- layer.mlp.up_scale.copy_(saved_layer['W_up_scale'])
85
- layer.mlp.gate_scale.copy_(saved_layer['W_gate_scale'])
86
- unpack_quip(layer.mlp.upgate_proj, saved_layer, codebook_id, codesz)
87
-
88
- glog.info(f'loading layer {ii} o')
89
- saved_layer = torch.load(f'{args.quantized_path}/{ii}_o.pt', map_location=cpu)
90
- layer.self_attn.o_scale.copy_(saved_layer['W_o_scale'])
91
- unpack_quip(layer.self_attn.o_proj, saved_layer, codebook_id, codesz)
92
-
93
- glog.info(f'loading layer {ii} down')
94
- saved_layer = torch.load(f'{args.quantized_path}/{ii}_down.pt', map_location=cpu)
95
- layer.mlp.down_scale.copy_(saved_layer['W_down_scale'])
96
-
97
- if model_config.quip_params['outlier_channel_split']:
98
- layer.mlp.down_proj.ocs_dupe_inds.copy_(torch.tensor(saved_layer['ocs_dupe_inds']))
99
-
100
- unpack_quip(layer.mlp.down_proj, saved_layer, codebook_id, codesz)
101
-
102
- else:
103
- saved_layer = torch.load(f'{args.quantized_path}/{ii}_q.pt', map_location=cpu)
104
- layer.self_attn.q_scale.copy_(saved_layer['W_scale'])
105
- if model_config.quip_params['outlier_channel_split']:
106
- layer.self_attn.q_proj.ocs_dupe_inds.copy_(
107
- torch.tensor(saved_layer['ocs_dupe_inds']))
108
- unpack_quip(layer.self_attn.q_proj, saved_layer, codebook_id, codesz)
109
-
110
- saved_layer = torch.load(f'{args.quantized_path}/{ii}_k.pt', map_location=cpu)
111
- layer.self_attn.k_scale.copy_(saved_layer['W_scale'])
112
- if model_config.quip_params['outlier_channel_split']:
113
- layer.self_attn.k_proj.ocs_dupe_inds.copy_(
114
- torch.tensor(saved_layer['ocs_dupe_inds']))
115
- unpack_quip(layer.self_attn.k_proj, saved_layer, codebook_id, codesz)
116
-
117
- saved_layer = torch.load(f'{args.quantized_path}/{ii}_v.pt', map_location=cpu)
118
- layer.self_attn.v_scale.copy_(saved_layer['W_scale'])
119
- if model_config.quip_params['outlier_channel_split']:
120
- layer.self_attn.v_proj.ocs_dupe_inds.copy_(
121
- torch.tensor(saved_layer['ocs_dupe_inds']))
122
- unpack_quip(layer.self_attn.v_proj, saved_layer, codebook_id, codesz)
123
-
124
- saved_layer = torch.load(f'{args.quantized_path}/{ii}_o.pt', map_location=cpu)
125
- layer.self_attn.o_scale.copy_(saved_layer['W_scale'])
126
- if model_config.quip_params['outlier_channel_split']:
127
- layer.self_attn.o_proj.ocs_dupe_inds.copy_(
128
- torch.tensor(saved_layer['ocs_dupe_inds']))
129
- unpack_quip(layer.self_attn.o_proj, saved_layer, codebook_id, codesz)
130
-
131
- saved_layer = torch.load(f'{args.quantized_path}/{ii}_up.pt', map_location=cpu)
132
- layer.mlp.up_scale.copy_(saved_layer['W_scale'])
133
- if model_config.quip_params['outlier_channel_split']:
134
- layer.mlp.up_proj.ocs_dupe_inds.copy_(torch.tensor(saved_layer['ocs_dupe_inds']))
135
- unpack_quip(layer.mlp.up_proj, saved_layer, codebook_id, codesz)
136
-
137
- saved_layer = torch.load(f'{args.quantized_path}/{ii}_gate.pt', map_location=cpu)
138
- layer.mlp.gate_scale.copy_(saved_layer['W_scale'])
139
- if model_config.quip_params['outlier_channel_split']:
140
- layer.mlp.gate_proj.ocs_dupe_inds.copy_(torch.tensor(saved_layer['ocs_dupe_inds']))
141
- unpack_quip(layer.mlp.gate_proj, saved_layer, codebook_id, codesz)
142
-
143
- saved_layer = torch.load(f'{args.quantized_path}/{ii}_down.pt', map_location=cpu)
144
- layer.mlp.down_scale.copy_(saved_layer['W_scale'])
145
- if model_config.quip_params['outlier_channel_split']:
146
- layer.mlp.down_proj.ocs_dupe_inds.copy_(torch.tensor(saved_layer['ocs_dupe_inds']))
147
- unpack_quip(layer.mlp.down_proj, saved_layer, codebook_id, codesz)
148
 
149
  glog.info(f'saving model...')
150
  model.save_pretrained(args.hf_output_path, safe_serialization=True)
 
5
  from transformers import AutoTokenizer
6
  from model.version import MODEL_VERSION
7
  from model.llama import LlamaForCausalLM as llama_fuse
 
8
  from model.mistral import MistralForCausalLM
9
  from lib import codebook
10
  from lib.utils.unsafe_import import model_from_hf_path
 
31
  module.B.copy_(saved_layer['B'])
32
  module.SU.copy_(saved_layer['SU'])
33
  module.SV.copy_(saved_layer['SV'])
 
34
  if module.rescale_WH:
35
  module.scaleWH.copy_(saved_layer['scaleWH'])
36
 
 
48
  tokenizer = AutoTokenizer.from_pretrained(model_config._name_or_path)
49
 
50
  model_type = model_config.model_type
 
51
  model_config.quip_params['model_version'] = MODEL_VERSION
52
 
53
  if model_type == 'llama':
54
+ model_cls = llama_fuse
55
  elif model_type == 'mistral':
56
  model_cls = MistralForCausalLM
57
  else:
 
68
  layer = model.model.layers[ii]
69
  cpu = torch.device('cpu')
70
 
71
+ glog.info(f'loading layer {ii} qkv')
72
+ saved_layer = torch.load(f'{args.quantized_path}/{ii}_qkv.pt', map_location=cpu)
73
+ layer.self_attn.qkv_proj.fuse_scales[0].copy_(saved_layer['W_q_scale'])
74
+ layer.self_attn.qkv_proj.fuse_scales[1].copy_(saved_layer['W_k_scale'])
75
+ layer.self_attn.qkv_proj.fuse_scales[2].copy_(saved_layer['W_v_scale'])
76
+ layer.self_attn.qkv_proj.Wscale.copy_(saved_layer['Wscale'])
77
+ unpack_quip(layer.self_attn.qkv_proj, saved_layer, codebook_id, codesz)
78
+
79
+ glog.info(f'loading layer {ii} up')
80
+ saved_layer = torch.load(f'{args.quantized_path}/{ii}_up.pt', map_location=cpu)
81
+ layer.mlp.upgate_proj.fuse_scales[0].copy_(saved_layer['W_up_scale'])
82
+ layer.mlp.upgate_proj.fuse_scales[1].copy_(saved_layer['W_gate_scale'])
83
+ layer.mlp.upgate_proj.Wscale.copy_(saved_layer['Wscale'])
84
+ unpack_quip(layer.mlp.upgate_proj, saved_layer, codebook_id, codesz)
85
+
86
+ glog.info(f'loading layer {ii} o')
87
+ saved_layer = torch.load(f'{args.quantized_path}/{ii}_o.pt', map_location=cpu)
88
+ layer.self_attn.o_proj.Wscale.copy_(saved_layer['W_o_scale'] * saved_layer['Wscale'])
89
+ unpack_quip(layer.self_attn.o_proj, saved_layer, codebook_id, codesz)
90
+
91
+ glog.info(f'loading layer {ii} down')
92
+ saved_layer = torch.load(f'{args.quantized_path}/{ii}_down.pt', map_location=cpu)
93
+ layer.mlp.down_proj.Wscale.copy_(saved_layer['W_down_scale'] * saved_layer['Wscale'])
94
+ if model_config.quip_params['outlier_channel_split']:
95
+ layer.mlp.down_proj.ocs_dupe_inds.copy_(torch.tensor(saved_layer['ocs_dupe_inds']))
96
+ unpack_quip(layer.mlp.down_proj, saved_layer, codebook_id, codesz)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  glog.info(f'saving model...')
99
  model.save_pretrained(args.hf_output_path, safe_serialization=True)
quip-sharp/lib/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/__pycache__/__init__.cpython-310.pyc and b/quip-sharp/lib/__pycache__/__init__.cpython-310.pyc differ
 
quip-sharp/lib/codebook/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/codebook/__pycache__/__init__.cpython-310.pyc and b/quip-sharp/lib/codebook/__pycache__/__init__.cpython-310.pyc differ
 
quip-sharp/lib/codebook/__pycache__/half_integer_4bit_1col.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/codebook/__pycache__/half_integer_4bit_1col.cpython-310.pyc and b/quip-sharp/lib/codebook/__pycache__/half_integer_4bit_1col.cpython-310.pyc differ
 
quip-sharp/lib/codebook/__pycache__/latticed4.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/codebook/__pycache__/latticed4.cpython-310.pyc and b/quip-sharp/lib/codebook/__pycache__/latticed4.cpython-310.pyc differ
 
quip-sharp/lib/codebook/__pycache__/latticee8_padded12.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/codebook/__pycache__/latticee8_padded12.cpython-310.pyc and b/quip-sharp/lib/codebook/__pycache__/latticee8_padded12.cpython-310.pyc differ
 
quip-sharp/lib/codebook/latticee8_padded12.py CHANGED
@@ -6,7 +6,6 @@ The total codebook is all 2^7 flips of these 256 entries (2^15) +- 1/4
6
  which makes 2^16 entries.
7
  This corresponds to a subset of E8 + 1/4
8
  """
9
-
10
  import torch
11
  import math
12
  from torch import nn
@@ -22,19 +21,12 @@ _INT_MAP = 2**(torch.arange(_E8P_CODESZ).flip(0))
22
  def int2mask(i, int_map):
23
  return ((i & int_map) > 0).int()
24
 
25
-
26
  def mask2int(mask, int_map):
27
  return (int_map.unsqueeze(0) * mask.int()).sum(dim=-1)
28
 
29
-
30
- def get_abs_grid():
31
- intr = torch.arange(-4, 4)
32
- d8 = torch.cartesian_prod(*[intr] * _E8P_CODESZ).float() + 1 / 2
33
- d8m2 = (d8.sum(dim=-1) % 2 == 0)
34
- d8n = d8.norm(dim=-1)**2 <= 10
35
- d8abs = torch.unique(d8[sorted(torch.where(d8m2 * d8n)[0])].abs(), dim=0)
36
-
37
- norm12 = torch.tensor([
38
  [3, 1, 1, 1, 3, 3, 3, 3],
39
  [1, 3, 1, 1, 3, 3, 3, 3],
40
  [1, 1, 3, 1, 3, 3, 3, 3],
@@ -62,82 +54,81 @@ def get_abs_grid():
62
  [1, 3, 3, 3, 1, 3, 3, 1],
63
  [1, 3, 3, 3, 3, 1, 1, 3],
64
  [1, 3, 3, 3, 1, 3, 1, 3],
65
- [1, 3, 3, 3, 1, 1, 3, 3],
66
  [3, 3, 1, 1, 3, 3, 3, 1],
67
  ]) / 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  return torch.concat([d8abs, norm12], dim=0)
69
 
70
 
71
- def get_full_grid(abs_grid):
72
- """
73
- idx format:
74
- - first 8 bits = which of the 256 entries in the abs grid
75
- - next 7 bits = which of the right 7 dims to negate (8th can be inferred)
76
- - last bit = +1/4 if true else -1/4
77
- """
78
- is_even_flips = abs_grid.sum(dim=-1) % 2 == 0
79
- abs_idxs = torch.arange(len(abs_grid)) << _E8P_CODESZ
80
- entries = [[], []]
81
- idxs = [[], []]
82
- for i in range(2**(_E8P_CODESZ - 1)):
83
- mask = int2mask(i, _INT_MAP)
84
- mask_even = (mask.sum(dim=-1) % 2 == 0)
85
- mask = mask.unsqueeze(0).repeat(len(abs_grid), 1)
86
- mask[:, 0] = mask_even != is_even_flips
87
- mask = 1 - 2 * mask
88
- entries[0].append(abs_grid * mask + 1 / 4)
89
- idxs[0].append(abs_idxs + (i << 1) + 1)
90
- entries[1].append(abs_grid * mask - 1 / 4)
91
- idxs[1].append(abs_idxs + (i << 1))
92
-
93
- for i in range(2):
94
- entries[i] = torch.concat(entries[i], dim=0)
95
- idxs[i] = torch.concat(idxs[i], dim=0)
96
- entries = torch.concat(entries, dim=0)
97
- idxs = torch.concat(idxs, dim=0)
98
- return entries, idxs
99
-
100
-
101
- _E8P_ABS_CACHED = get_abs_grid()
102
- _E8P_GRID, _E8P_GRID_IDX = get_full_grid(_E8P_ABS_CACHED)
103
 
104
 
105
  class E8P12_codebook(nn.Module):
106
 
107
  def __init__(self, inference=False):
108
  super(E8P12_codebook, self).__init__()
109
- self.opt_scale = 1 #.03#/1.09
110
  self.codesz = _E8P_CODESZ
111
- self.idx_dtype = torch.int16
112
- self.idx_offset = -2**15
113
- self.packsz = 1
114
  self.pack_out = False
115
- self.version = 0
116
 
117
- self.register_buffer('grid_abs', _E8P_ABS_CACHED)
118
- self.register_buffer('grid_abs_even', self.grid_abs.sum(dim=-1) % 2 == 0)
119
 
120
  if not inference:
121
- self.register_buffer('int_map', _INT_MAP)
122
  self.register_buffer('grid', _E8P_GRID)
123
- self.register_buffer('grid_idx_map',
124
- (_E8P_GRID_IDX + self.idx_offset).to(self.idx_dtype))
125
- idx_lut = torch.zeros(_E8P_GRID_IDX.shape).int()
126
- idx_lut[_E8P_GRID_IDX] = torch.arange(len(_E8P_GRID_IDX)).int()
127
- self.register_buffer('grid_idx_inv', idx_lut)
128
-
129
- self.register_buffer('grid_norm', torch.diag(self.grid @ self.grid.T))
130
- grid_part = self.grid[:len(self.grid) // 2] - 1 / 4
131
- idxs = torch.where(
132
- ((grid_part[:, 1:] < 0).sum(dim=-1) <= 1) * \
133
- (grid_part[:, 1:].min(dim=-1).values >= -0.5)
134
- )[0]
135
- grid_part = grid_part[idxs]
136
- self.register_buffer('grid_part', grid_part)
137
- self.register_buffer('grid_part_norm', torch.diag(grid_part @ grid_part.T))
138
- allcombo_idx, idx_map = self.iterate_mask()
139
- self.register_buffer('allcombo_idx', allcombo_idx)
140
- self.register_buffer('idx_map', idx_map)
141
  '''
142
  self.to('cuda')
143
  samples = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(8), torch.eye(8)).rsample([2000000]).cuda()
@@ -146,60 +137,44 @@ class E8P12_codebook(nn.Module):
146
  exit()
147
  '''
148
 
149
- def iterate_mask(self, device=0):
150
- flips = torch.stack([((torch.tensor([i]) & self.int_map) > 0).int()
151
- for i in range(2**_E8P_CODESZ)]).to(device)
152
- raw_idx = torch.where(flips.sum(dim=-1) % 2 == 0)[0]
153
- flips = 1 - 2 * flips[raw_idx]
154
- idx_map = torch.zeros(2**_E8P_CODESZ, dtype=torch.int32)
155
- for i in range(len(raw_idx)):
156
- idx_map[raw_idx[i]] = i
157
- allcombo = flips.unsqueeze(1) * self.grid_part.unsqueeze(0).to(device)
158
- allcombo_idx = torch.zeros(allcombo.shape[0:2]).int()
159
- for i in range(len(allcombo)):
160
- allcombo_idx[i] = self.round(allcombo[i], self.grid.to(device),
161
- self.grid_norm.to(device))[1]
162
- return allcombo_idx.cpu(), idx_map.cpu()
163
-
164
  def round(self, X, grid, grid_norm):
165
  assert X.shape[-1] == self.codesz
166
  Xqidx = (2 * X @ grid.T - grid_norm).argmax(-1)
167
  return grid[Xqidx], Xqidx
168
 
169
- def fast_quantize_part(self, X):
170
- X_part = torch.abs(X)
171
- X_odd = torch.where((X < 0).sum(dim=-1) % 2 != 0)[0]
172
- X_part[X_odd, 0] = -X_part[X_odd, 0]
173
- mask = 1 - 2 * (X < 0).to(torch.float32)
174
- mask[X_odd, 0] = -mask[X_odd, 0]
175
- roundout, Xqidx = self.round(X_part, self.grid_part, self.grid_part_norm)
176
- vals = roundout * mask
177
- real_idx = self.allcombo_idx[self.idx_map[mask2int((1 - mask) / 2, self.int_map)], Xqidx]
178
- err = (X - vals).norm(dim=-1)
179
- return vals, real_idx, err
180
-
181
  def quantize(self, X, return_idx=True):
182
- X_plus = X + 1 / 4 # quantize X to D8^ - 1/4
183
- X_minus = X - 1 / 4 # quantize X to D8^ + 1/4
184
-
185
- plus_vals, plus_idx, plus_err = self.fast_quantize_part(X_plus)
186
- minus_vals, minus_idx, minus_err = self.fast_quantize_part(X_minus)
187
- plus_idx = plus_idx + 2**15
188
-
189
- which = plus_err < minus_err
190
- final_vals = torch.where(which.unsqueeze(-1), plus_vals - 1 / 4, minus_vals + 1 / 4)
191
-
192
  if return_idx:
193
- final_idxs = self.grid_idx_map[torch.where(which, plus_idx, minus_idx)]
194
  return final_vals, final_idxs
195
-
196
  return final_vals
197
 
198
- def maybe_pack_idxs(self, idxs):
199
- return idxs
200
-
  def by_idxs(self, idxs, **kwargs):
202
- return self.grid[self.grid_idx_inv[idxs.int() - self.idx_offset]]
203
 
204
 
205
  class QuantizedE8P12Linear(nn.Module):
@@ -207,10 +182,6 @@ class QuantizedE8P12Linear(nn.Module):
207
  def __init__(self, device):
208
  super().__init__()
209
  self.codebook = E8P12_codebook(inference=True).to(torch.float16).to(device)
210
- self.codebook_matvec = torch.zeros((256, ), dtype=torch.int64, device=device)
211
- for i in range(8):
212
- chunk = (self.codebook.grid_abs[:, i] * 4).to(torch.int64)
213
- self.codebook_matvec |= chunk << (i * 8)
214
 
215
  def forward(self,
216
  input,
@@ -228,9 +199,9 @@ class QuantizedE8P12Linear(nn.Module):
228
  rescale_WH=False,
229
  scaleWH=None,
230
  **kwargs):
231
- (m, n) = Qidxs.shape
232
 
233
- x = input.view(-1, n * _E8P_CODESZ).to(torch.float32)
234
  if rescale_WH:
235
  x /= scaleWH
236
  x = x * SU
@@ -240,17 +211,17 @@ class QuantizedE8P12Linear(nn.Module):
240
  Bx = x @ B.t().to(torch.float32)
241
  ABx = Bx @ A.t().to(torch.float32)
242
 
243
- # TODO: find the optimal threshold
244
- if x.size(0) < 6:
245
- x = quiptools_cuda.decode_matmul_e8p(x, Qidxs - 0x8000,
246
- self.codebook_matvec).to(torch.float32)
 
 
247
  else:
248
- W_decompressed = torch.zeros(m,
249
- n * _E8P_CODESZ,
250
- device=Qidxs.device,
251
- dtype=torch.float16)
252
- quiptools_cuda.decompress_e8p_origorder(Qidxs, self.codebook.grid_abs,
253
- self.codebook.grid_abs_even, W_decompressed)
254
  x = (x.to(torch.float16) @ W_decompressed.T).to(torch.float32)
255
 
256
  x *= Wscale
 
6
  which makes 2^16 entries.
7
  This corresponds to a subset of E8 + 1/4
8
  """
 
9
  import torch
10
  import math
11
  from torch import nn
 
21
  def int2mask(i, int_map):
22
  return ((i & int_map) > 0).int()
23
 
 
24
  def mask2int(mask, int_map):
25
  return (int_map.unsqueeze(0) * mask.int()).sum(dim=-1)
26
 
27
+ def get_norm12():
28
+ # 29 elements of norm 12 in E8 + 1/4
29
+ return torch.tensor([
30
  [3, 1, 1, 1, 3, 3, 3, 3],
31
  [1, 3, 1, 1, 3, 3, 3, 3],
32
  [1, 1, 3, 1, 3, 3, 3, 3],
 
54
  [1, 3, 3, 3, 1, 3, 3, 1],
55
  [1, 3, 3, 3, 3, 1, 1, 3],
56
  [1, 3, 3, 3, 1, 3, 1, 3],
57
+ [1, 1, 3, 3, 1, 3, 3, 3],
58
  [3, 3, 1, 1, 3, 3, 3, 1],
59
  ]) / 2
60
+
61
+
62
+ def get_packed_abs_grid():
63
+ intr = torch.arange(-4, 4)
64
+ d8 = torch.cartesian_prod(*[intr] * 8).float() + 1 / 2
65
+ d8m2 = (d8.sum(dim=-1) % 2 == 0)
66
+ d8n = d8.norm(dim=-1)**2 <= 10
67
+ d8abs = torch.unique(d8[sorted(torch.where(d8m2 * d8n)[0])].abs(), dim=0)
68
+ norm12 = get_norm12()
69
+ cba = torch.concat([d8abs, norm12], dim=0)
70
+ cba = cba[:, [0, 2, 4, 6, 1, 3, 5, 7]]
71
+ cba[:,7] *= (1 - 2 * (cba.sum(1) % 2))
72
+ cba = cba * 2 + 8
73
+ cba = cba.to(torch.int32)
74
+ acc = cba[:,0]
75
+ for i in range(7):
76
+ acc = acc | (cba[:,(i+1)] << ((i+1)*4))
77
+ return acc
78
+
79
+
80
+ def get_abs_grid():
81
+ intr = torch.arange(-4, 4)
82
+ d8 = torch.cartesian_prod(*[intr] * _E8P_CODESZ).float() + 1 / 2
83
+ d8m2 = (d8.sum(dim=-1) % 2 == 0)
84
+ d8n = d8.norm(dim=-1)**2 <= 10
85
+ d8abs = torch.unique(d8[sorted(torch.where(d8m2 * d8n)[0])].abs(), dim=0)
86
+ norm12 = get_norm12()
87
  return torch.concat([d8abs, norm12], dim=0)
88
 
89
 
90
+ def get_full_grid(packed_abs_grid):
91
+ synth_codebook = torch.zeros(1 << 16, 8)
92
+ shuffle_map = [0,4,1,5,2,6,3,7]
93
+ for c in range(1 << 16):
94
+ signs = c & 255
95
+ abs = c >> 8
96
+ parity = 0
97
+ for i in range(8):
98
+ parity = parity ^ ((signs >> i) & 1)
99
+ signs = signs ^ parity
100
+ abs_code = packed_abs_grid[abs].item()
101
+ for i in range(8):
102
+ ii = shuffle_map[i]
103
+ synth_codebook[c,i] = (((abs_code >> (4 * ii)) & 15) - 8) * 0.5
104
+ if ((signs >> ii) & 1):
105
+ synth_codebook[c,i] *= -1
106
+ if parity:
107
+ synth_codebook[c,:] -= 0.25
108
+ else:
109
+ synth_codebook[c,:] += 0.25
110
+ return synth_codebook, torch.arange(1 << 16)
111
+
112
+ _E8P_PACKED_ABS_CACHED = get_packed_abs_grid()
113
+ _E8P_GRID, _E8P_GRID_IDX = get_full_grid(_E8P_PACKED_ABS_CACHED)
114
 
115
 
116
  class E8P12_codebook(nn.Module):
117
 
118
  def __init__(self, inference=False):
119
  super(E8P12_codebook, self).__init__()
120
+ self.opt_scale = 1.03
121
  self.codesz = _E8P_CODESZ
122
+ self.idx_dtype = torch.int64
123
+ self.packsz = 4
 
124
  self.pack_out = False
125
+ self.version = 1
126
 
127
+ self.register_buffer('grid_packed_abs', _E8P_PACKED_ABS_CACHED)
 
128
 
129
  if not inference:
 
130
  self.register_buffer('grid', _E8P_GRID)
131
+ self.register_buffer('grid_norm', _E8P_GRID.norm(dim=-1)**2)
132
  '''
133
  self.to('cuda')
134
  samples = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(8), torch.eye(8)).rsample([2000000]).cuda()
 
137
  exit()
138
  '''
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  def round(self, X, grid, grid_norm):
141
  assert X.shape[-1] == self.codesz
142
  Xqidx = (2 * X @ grid.T - grid_norm).argmax(-1)
143
  return grid[Xqidx], Xqidx
144
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  def quantize(self, X, return_idx=True):
146
+ final_vals, final_idxs = self.round(X, self.grid, self.grid_norm)
147
  if return_idx:
 
148
  return final_vals, final_idxs
 
149
  return final_vals
150
 
151
+ def maybe_pack_idxs(self, idxs):
152
+ m, n = idxs.shape
153
+ idxs = idxs.view(m//2, 2, (n*8)//16, 2).transpose(1, 2).contiguous()
154
+
155
+ abs32 = (idxs[:, :, 0, 0] >> 8) + \
156
+ ((idxs[:, :, 1, 0] >> 8) << 8) + \
157
+ ((idxs[:, :, 0, 1] >> 8) << 16) + \
158
+ ((idxs[:, :, 1, 1] >> 8) << 24)
159
+
160
+ sign32 = torch.zeros(abs32.shape, dtype=abs32.dtype, device=abs32.device)
161
+ for i in range(4):
162
+ wt = idxs[:, :, i % 2, i // 2]
163
+ for j in range(8):
164
+ sign32 += ((wt >> j) & 1) << (4*j + i)
165
+
166
+ output = (sign32 << 32) + abs32
167
+ output = output.reshape(m//16, 8, n//8, 4).transpose(1, 2).contiguous()
168
+ return output.view(m, n//4)
169
+
170
  def by_idxs(self, idxs, **kwargs):
171
+ m, n = idxs.shape
172
+ W_decompressed = quiptools_cuda.decompress_packed_e8p(
173
+ idxs.view(m//16, n//2, 8, 4),
174
+ self.grid_packed_abs
175
+ )
176
+ return W_decompressed
177
+
178
 
179
 
180
  class QuantizedE8P12Linear(nn.Module):
 
182
  def __init__(self, device):
183
  super().__init__()
184
  self.codebook = E8P12_codebook(inference=True).to(torch.float16).to(device)
 
 
 
 
185
 
186
  def forward(self,
187
  input,
 
199
  rescale_WH=False,
200
  scaleWH=None,
201
  **kwargs):
202
+ n, m = len(SU), len(SV)
203
 
204
+ x = input.view(-1, n).to(torch.float32)
205
  if rescale_WH:
206
  x /= scaleWH
207
  x = x * SU
 
211
  Bx = x @ B.t().to(torch.float32)
212
  ABx = Bx @ A.t().to(torch.float32)
213
 
214
+ if x.size(0) == 1:
215
+ x = quiptools_cuda.decode_matvec_e8p(
216
+ x[0].to(torch.float16),
217
+ Qidxs.view(m//16, n//64, 8, 4),
218
+ self.codebook.grid_packed_abs
219
+ ).to(torch.float32)
220
  else:
221
+ W_decompressed = quiptools_cuda.decompress_packed_e8p(
222
+ Qidxs.view(m//16, n//64, 8, 4),
223
+ self.codebook.grid_packed_abs
224
+ )
 
 
225
  x = (x.to(torch.float16) @ W_decompressed.T).to(torch.float32)
226
 
227
  x *= Wscale
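
For reference, the new packed codebook above stores each 8-dimensional abs entry as eight 4-bit nibbles of one integer: `get_packed_abs_grid` maps each half-integer coordinate to `value * 2 + 8` in [0, 15] and places coordinate i at bits 4*i..4*i+3, and `get_full_grid` reverses that. A minimal sketch of just the nibble format, ignoring the column shuffle and the sign/parity handling the real code also applies:

```python
import torch

def unpack_abs_entry(packed: int) -> torch.Tensor:
    # Reverse of value -> value * 2 + 8 packed as 4-bit nibbles (as in
    # get_full_grid above, minus its shuffle_map and sign handling).
    coords = [(packed >> (4 * i)) & 15 for i in range(8)]
    return (torch.tensor(coords, dtype=torch.float32) - 8) * 0.5

# Pack one illustrative half-integer row the same way, then round-trip it.
row = torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 1.5, 1.5, 2.5])
nibbles = (row * 2 + 8).to(torch.int64)
packed = 0
for i in range(8):
    packed |= int(nibbles[i]) << (4 * i)
assert torch.equal(unpack_abs_entry(packed), row)
```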
quip-sharp/lib/linear/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/linear/__pycache__/__init__.cpython-310.pyc and b/quip-sharp/lib/linear/__pycache__/__init__.cpython-310.pyc differ
 
quip-sharp/lib/linear/__pycache__/fused_quantized_linear.cpython-310.pyc ADDED
Binary file (1.43 kB)
 
quip-sharp/lib/linear/__pycache__/quantized_linear.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/linear/__pycache__/quantized_linear.cpython-310.pyc and b/quip-sharp/lib/linear/__pycache__/quantized_linear.cpython-310.pyc differ
 
quip-sharp/lib/linear/fused_quantized_linear.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import quiptools_cuda
4
+ from lib.utils import dtype_from_str, get_hadK
5
+ from lib import codebook
6
+ from .quantized_linear import QuantizedLinear
7
+ import time
8
+
9
+
10
+ class FusedQuantizedLinear(QuantizedLinear):
11
+
12
+ def __init__(self, fuse_dim, fuse_sizes, *QL_args, **QL_kwargs):
13
+ super(FusedQuantizedLinear, self).__init__(*QL_args, **QL_kwargs)
14
+ self.fuse_dim = fuse_dim
15
+ self.fuse_sizes = fuse_sizes
16
+ self.register_buffer('fuse_scales', torch.ones(len(self.fuse_sizes)))
17
+ self.n = len(self.fuse_sizes)
18
+
19
+ def forward(self, input):
20
+ fused_output = super(FusedQuantizedLinear, self).forward(input)
21
+ split_outputs = torch.split(fused_output, self.fuse_sizes, self.fuse_dim)
22
+ return tuple(split_outputs[i] * self.fuse_scales[i] for i in range(self.n))
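
The new `FusedQuantizedLinear.forward` above runs one quantized matmul for several stacked projections, splits the output along `fuse_dim`, and multiplies each chunk by its own `fuse_scales` entry. A small standalone sketch of that split-and-rescale step; the sizes below are made up, and in the real layer `fused_output` comes from the quantized matmul:

```python
import torch

# Hypothetical fused q/k/v widths and per-projection scales.
fuse_dim, fuse_sizes = -1, (4096, 1024, 1024)
fuse_scales = torch.tensor([1.00, 0.90, 1.10])
fused_output = torch.randn(2, 8, sum(fuse_sizes))  # (batch, seq, fused width)

# Same split-and-rescale as FusedQuantizedLinear.forward.
chunks = torch.split(fused_output, fuse_sizes, dim=fuse_dim)
q, k, v = (chunks[i] * fuse_scales[i] for i in range(len(fuse_sizes)))
assert (q.shape[-1], k.shape[-1], v.shape[-1]) == fuse_sizes
```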
quip-sharp/lib/linear/quantized_linear.py CHANGED
@@ -18,7 +18,8 @@ class QuantizedLinear(nn.Module):
18
  codebook_version,
19
  outlier_channel_split=False,
20
  rank=-1,
21
- rescale_WH=False):
 
22
  super().__init__()
23
 
24
  self.in_features = in_features
@@ -27,6 +28,10 @@ class QuantizedLinear(nn.Module):
27
  self.rank = rank
28
  self.rescale_WH = rescale_WH
29
 
 
 
 
 
30
  if self.outlier_channel_split:
31
  self.register_buffer('ocs_dupe_inds', torch.arange(in_features))
32
 
@@ -87,18 +92,22 @@ class QuantizedLinear(nn.Module):
87
  if self.outlier_channel_split:
88
  input = input[..., self.ocs_dupe_inds]
89
 
90
- return self.codebook_class(input,
91
- self.Qidxs,
92
- self.SU,
93
- self.SV,
94
- self.Wscale,
95
- self.had_left,
96
- self.had_right,
97
- self.K_left,
98
- self.K_right,
99
- rank=self.rank,
100
- A=self.A,
101
- B=self.B,
102
- rescale_WH=self.rescale_WH,
103
- scaleWH=self.scaleWH,
104
- packed=self.packed)
 
 
 
 
 
18
  codebook_version,
19
  outlier_channel_split=False,
20
  rank=-1,
21
+ rescale_WH=False,
22
+ bias=False):
23
  super().__init__()
24
 
25
  self.in_features = in_features
 
28
  self.rank = rank
29
  self.rescale_WH = rescale_WH
30
 
31
+ self.has_bias = bias
32
+ if self.has_bias:
33
+ self.register_buffer('bias', torch.ones(out_features))
34
+
35
  if self.outlier_channel_split:
36
  self.register_buffer('ocs_dupe_inds', torch.arange(in_features))
37
 
 
92
  if self.outlier_channel_split:
93
  input = input[..., self.ocs_dupe_inds]
94
 
95
+ result = self.codebook_class(input,
96
+ self.Qidxs,
97
+ self.SU,
98
+ self.SV,
99
+ self.Wscale,
100
+ self.had_left,
101
+ self.had_right,
102
+ self.K_left,
103
+ self.K_right,
104
+ rank=self.rank,
105
+ A=self.A,
106
+ B=self.B,
107
+ rescale_WH=self.rescale_WH,
108
+ scaleWH=self.scaleWH,
109
+ packed=self.packed)
110
+ if self.has_bias:
111
+ return result + self.bias
112
+ return result
113
+
quip-sharp/lib/utils/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/utils/__pycache__/__init__.cpython-310.pyc and b/quip-sharp/lib/utils/__pycache__/__init__.cpython-310.pyc differ
 
quip-sharp/lib/utils/__pycache__/data_utils.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/utils/__pycache__/data_utils.cpython-310.pyc and b/quip-sharp/lib/utils/__pycache__/data_utils.cpython-310.pyc differ
 
quip-sharp/lib/utils/__pycache__/lm_eval_adaptor.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/utils/__pycache__/lm_eval_adaptor.cpython-310.pyc and b/quip-sharp/lib/utils/__pycache__/lm_eval_adaptor.cpython-310.pyc differ
 
quip-sharp/lib/utils/__pycache__/math_utils.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/utils/__pycache__/math_utils.cpython-310.pyc and b/quip-sharp/lib/utils/__pycache__/math_utils.cpython-310.pyc differ
 
quip-sharp/lib/utils/__pycache__/matmul_had.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/utils/__pycache__/matmul_had.cpython-310.pyc and b/quip-sharp/lib/utils/__pycache__/matmul_had.cpython-310.pyc differ
 
quip-sharp/lib/utils/__pycache__/matmul_kron.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/utils/__pycache__/matmul_kron.cpython-310.pyc and b/quip-sharp/lib/utils/__pycache__/matmul_kron.cpython-310.pyc differ
 
quip-sharp/lib/utils/__pycache__/misc.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/utils/__pycache__/misc.cpython-310.pyc and b/quip-sharp/lib/utils/__pycache__/misc.cpython-310.pyc differ
 
quip-sharp/lib/utils/__pycache__/unsafe_import.cpython-310.pyc CHANGED
Binary files a/quip-sharp/lib/utils/__pycache__/unsafe_import.cpython-310.pyc and b/quip-sharp/lib/utils/__pycache__/unsafe_import.cpython-310.pyc differ
 
quip-sharp/lib/utils/data_utils.py CHANGED
@@ -58,7 +58,6 @@ def block_LDL(H, b):
58
  def wrap_tokenizer(tokenizer, x, ctx_size):
59
  return tokenizer(x, return_tensors='pt', truncation=True, padding=True, max_length=ctx_size)
60
 
61
-
62
  def sample_devset(dataset, tokenizer, size=128, ctx_size=2048, nproc=1):
63
  devset = torch.zeros((size, ctx_size), dtype=torch.int64)
64
  saved = 0
@@ -122,6 +121,7 @@ def load_quip(save_name, cb, args, device):
122
 
123
  def dtype_from_str(str):
124
  dtype_map = {
 
125
  'torch.int32': torch.int32,
126
  'torch.int16': torch.int16,
127
  'torch.uint8': torch.uint8,
 
58
  def wrap_tokenizer(tokenizer, x, ctx_size):
59
  return tokenizer(x, return_tensors='pt', truncation=True, padding=True, max_length=ctx_size)
60
 
 
61
  def sample_devset(dataset, tokenizer, size=128, ctx_size=2048, nproc=1):
62
  devset = torch.zeros((size, ctx_size), dtype=torch.int64)
63
  saved = 0
 
121
 
122
  def dtype_from_str(str):
123
  dtype_map = {
124
+ 'torch.int64': torch.int64,
125
  'torch.int32': torch.int32,
126
  'torch.int16': torch.int16,
127
  'torch.uint8': torch.uint8,
quip-sharp/lib/utils/unsafe_import.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from model.graph_wrapper import get_graph_wrapper
4
  from model.llama import LlamaForCausalLM as llama_fuse
5
- from model.llama_nofuse import LlamaForCausalLM as llama_nofuse
6
  from model.mistral import MistralForCausalLM
7
  import json
8
  import os
@@ -17,10 +16,9 @@ def model_from_hf_path(path, use_cuda_graph=True, use_flash_attn=True):
17
  is_quantized = hasattr(bad_config, 'quip_params')
18
  model_type = bad_config.model_type
19
  if is_quantized:
20
- fused = bad_config.quip_params.get('fused', True)
21
  if model_type == 'llama':
22
  model_str = transformers.LlamaConfig.from_pretrained(path)._name_or_path
23
- model_cls = llama_fuse if fused else llama_nofuse
24
  elif model_type == 'mistral':
25
  model_str = transformers.MistralConfig.from_pretrained(path)._name_or_path
26
  model_cls = MistralForCausalLM
 
2
 
3
  from model.graph_wrapper import get_graph_wrapper
4
  from model.llama import LlamaForCausalLM as llama_fuse
 
5
  from model.mistral import MistralForCausalLM
6
  import json
7
  import os
 
16
  is_quantized = hasattr(bad_config, 'quip_params')
17
  model_type = bad_config.model_type
18
  if is_quantized:
 
19
  if model_type == 'llama':
20
  model_str = transformers.LlamaConfig.from_pretrained(path)._name_or_path
21
+ model_cls = llama_fuse
22
  elif model_type == 'mistral':
23
  model_str = transformers.MistralConfig.from_pretrained(path)._name_or_path
24
  model_cls = MistralForCausalLM
quip-sharp/model/__pycache__/graph_wrapper.cpython-310.pyc CHANGED
Binary files a/quip-sharp/model/__pycache__/graph_wrapper.cpython-310.pyc and b/quip-sharp/model/__pycache__/graph_wrapper.cpython-310.pyc differ
 
quip-sharp/model/__pycache__/llama.cpython-310.pyc CHANGED
Binary files a/quip-sharp/model/__pycache__/llama.cpython-310.pyc and b/quip-sharp/model/__pycache__/llama.cpython-310.pyc differ
 
quip-sharp/model/__pycache__/mistral.cpython-310.pyc CHANGED
Binary files a/quip-sharp/model/__pycache__/mistral.cpython-310.pyc and b/quip-sharp/model/__pycache__/mistral.cpython-310.pyc differ
 
quip-sharp/model/__pycache__/version.cpython-310.pyc CHANGED
Binary files a/quip-sharp/model/__pycache__/version.cpython-310.pyc and b/quip-sharp/model/__pycache__/version.cpython-310.pyc differ
 
quip-sharp/model/llama.py CHANGED
@@ -48,6 +48,7 @@ if is_flash_attn_available():
48
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
49
 
50
  from lib.linear.quantized_linear import QuantizedLinear
 
51
  from .version import check_model_version
52
 
53
  logger = logging.get_logger(__name__)
@@ -225,15 +226,17 @@ class LlamaMLP(nn.Module):
225
  self.config = config
226
  self.hidden_size = config.hidden_size
227
  self.intermediate_size = config.intermediate_size
228
- self.upgate_proj = QuantizedLinear(self.hidden_size,
229
- self.intermediate_size * 2,
230
- config.quip_params['codesz'],
231
- config.quip_params.get('packsz', 1),
232
- config.quip_params.get('pack_out', False),
233
- config.quip_params['idx_dtype'],
234
- config.quip_params.get('codebook_version', 0),
235
- rank=config.quip_params['lora_rank'],
236
- rescale_WH=config.quip_params['rescale_WH'])
 
 
237
  self.down_proj = QuantizedLinear(
238
  self.config.quip_params['ocs_down_size'] if \
239
  self.config.quip_params['outlier_channel_split'] else self.intermediate_size,
@@ -246,24 +249,14 @@ class LlamaMLP(nn.Module):
246
  outlier_channel_split=self.config.quip_params['outlier_channel_split'],
247
  rank=self.config.quip_params['lora_rank'],
248
  rescale_WH=self.config.quip_params['rescale_WH'])
249
- self.register_buffer('up_scale', nn.Parameter(torch.ones(())))
250
- self.register_buffer('gate_scale', nn.Parameter(torch.ones(())))
251
- self.register_buffer('down_scale', nn.Parameter(torch.ones(())))
252
  self.act_fn = ACT2FN[config.hidden_act]
253
 
254
  def forward(self, x):
255
  if self.config.pretraining_tp > 1:
256
  raise Exception
257
- else:
258
- upgate_proj = self.upgate_proj(x.to(torch.float32))
259
- up_proj = self.up_scale * upgate_proj[...,
260
- 0:self.intermediate_size]
261
- gate_proj = self.gate_scale * upgate_proj[
262
- ..., self.intermediate_size:(self.intermediate_size * 2)]
263
- down_proj = self.down_scale * self.down_proj(
264
- self.act_fn(gate_proj) * up_proj)
265
 
266
- return down_proj.half()
 
267
 
268
 
269
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -297,7 +290,12 @@ class LlamaAttention(nn.Module):
297
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
298
  f" and `num_heads`: {self.num_heads})."
299
  )
300
- self.qkv_proj = QuantizedLinear(
 
 
 
 
 
301
  self.hidden_size, (self.num_heads * self.head_dim) +
302
  (self.num_key_value_heads * self.head_dim) +
303
  (self.num_key_value_heads * self.head_dim),
@@ -308,7 +306,7 @@ class LlamaAttention(nn.Module):
308
  config.quip_params.get('codebook_version', 0),
309
  rank=config.quip_params['lora_rank'],
310
  rescale_WH=config.quip_params['rescale_WH'])
311
-
312
  self.o_proj = QuantizedLinear(self.num_heads * self.head_dim,
313
  self.hidden_size,
314
  config.quip_params['codesz'],
@@ -319,10 +317,6 @@ class LlamaAttention(nn.Module):
319
  rank=config.quip_params['lora_rank'],
320
  rescale_WH=config.quip_params['rescale_WH'])
321
 
322
- self.register_buffer('q_scale', nn.Parameter(torch.ones(())))
323
- self.register_buffer('k_scale', nn.Parameter(torch.ones(())))
324
- self.register_buffer('v_scale', nn.Parameter(torch.ones(())))
325
- self.register_buffer('o_scale', nn.Parameter(torch.ones(())))
326
  self._init_rope()
327
 
328
  def _init_rope(self):
@@ -370,19 +364,7 @@ class LlamaAttention(nn.Module):
370
  if self.config.pretraining_tp > 1:
371
  assert (False)
372
  else:
373
- qkv_states = self.qkv_proj(hidden_states.to(torch.float32))
374
- query_states = self.q_scale * qkv_states[..., 0:(self.num_heads *
375
- self.head_dim)]
376
- key_states = self.k_scale * qkv_states[..., (
377
- self.num_heads * self.head_dim):(
378
- (self.num_heads * self.head_dim) +
379
- (self.num_key_value_heads * self.head_dim))]
380
- value_states = self.v_scale * qkv_states[..., (
381
- (self.num_heads * self.head_dim) +
382
- (self.num_key_value_heads * self.head_dim)):(
383
- (self.num_heads * self.head_dim) +
384
- (self.num_key_value_heads * self.head_dim) +
385
- (self.num_key_value_heads * self.head_dim))]
386
  query_states = query_states.half()
387
  key_states = key_states.half()
388
  value_states = value_states.half()
@@ -439,7 +421,7 @@ class LlamaAttention(nn.Module):
439
  if self.config.pretraining_tp > 1:
440
  assert (False)
441
  else:
442
- attn_output = (self.o_scale * self.o_proj(attn_output)).half()
443
 
444
  if not output_attentions:
445
  attn_weights = None
@@ -468,19 +450,7 @@ class LlamaFlashAttention2(LlamaAttention):
468
  output_attentions = False
469
 
470
  bsz, q_len, _ = hidden_states.size()
471
- qkv_states = self.qkv_proj(hidden_states.to(torch.float32))
472
- query_states = self.q_scale * qkv_states[..., 0:(self.num_heads *
473
- self.head_dim)]
474
- key_states = self.k_scale * qkv_states[..., (
475
- self.num_heads * self.head_dim):(
476
- (self.num_heads * self.head_dim) +
477
- (self.num_key_value_heads * self.head_dim))]
478
- value_states = self.v_scale * qkv_states[..., (
479
- (self.num_heads * self.head_dim) +
480
- (self.num_key_value_heads * self.head_dim)):(
481
- (self.num_heads * self.head_dim) +
482
- (self.num_key_value_heads * self.head_dim) +
483
- (self.num_key_value_heads * self.head_dim))]
484
  query_states = query_states.half()
485
  key_states = key_states.half()
486
  value_states = value_states.half()
@@ -538,7 +508,7 @@ class LlamaFlashAttention2(LlamaAttention):
538
  )
539
 
540
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
541
- attn_output = (self.o_scale * self.o_proj(attn_output)).half()
542
 
543
  if not output_attentions:
544
  attn_weights = None
 
48
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
49
 
50
  from lib.linear.quantized_linear import QuantizedLinear
51
+ from lib.linear.fused_quantized_linear import FusedQuantizedLinear
52
  from .version import check_model_version
53
 
54
  logger = logging.get_logger(__name__)
 
226
  self.config = config
227
  self.hidden_size = config.hidden_size
228
  self.intermediate_size = config.intermediate_size
229
+ self.upgate_proj = FusedQuantizedLinear(
230
+ -1, (self.intermediate_size, self.intermediate_size),
231
+ self.hidden_size,
232
+ self.intermediate_size * 2,
233
+ config.quip_params['codesz'],
234
+ config.quip_params.get('packsz', 1),
235
+ config.quip_params.get('pack_out', False),
236
+ config.quip_params['idx_dtype'],
237
+ config.quip_params.get('codebook_version', 0),
238
+ rank=config.quip_params['lora_rank'],
239
+ rescale_WH=config.quip_params['rescale_WH'])
240
  self.down_proj = QuantizedLinear(
241
  self.config.quip_params['ocs_down_size'] if \
242
  self.config.quip_params['outlier_channel_split'] else self.intermediate_size,
 
249
  outlier_channel_split=self.config.quip_params['outlier_channel_split'],
250
  rank=self.config.quip_params['lora_rank'],
251
  rescale_WH=self.config.quip_params['rescale_WH'])
 
 
 
252
  self.act_fn = ACT2FN[config.hidden_act]
253
 
254
  def forward(self, x):
255
  if self.config.pretraining_tp > 1:
256
  raise Exception
257
 
258
+ up_proj, gate_proj = self.upgate_proj(x.to(torch.float32))
259
+ return self.down_proj(self.act_fn(gate_proj) * up_proj).half()
260
 
261
 
262
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
290
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
291
  f" and `num_heads`: {self.num_heads})."
292
  )
293
+
294
+ self.qkv_proj = FusedQuantizedLinear(
295
+ -1,
296
+ (self.num_heads*self.head_dim,
297
+ self.num_key_value_heads*self.head_dim,
298
+ self.num_key_value_heads*self.head_dim),
299
  self.hidden_size, (self.num_heads * self.head_dim) +
300
  (self.num_key_value_heads * self.head_dim) +
301
  (self.num_key_value_heads * self.head_dim),
 
306
  config.quip_params.get('codebook_version', 0),
307
  rank=config.quip_params['lora_rank'],
308
  rescale_WH=config.quip_params['rescale_WH'])
309
+
310
  self.o_proj = QuantizedLinear(self.num_heads * self.head_dim,
311
  self.hidden_size,
312
  config.quip_params['codesz'],
 
317
  rank=config.quip_params['lora_rank'],
318
  rescale_WH=config.quip_params['rescale_WH'])
319
320
  self._init_rope()
321
 
322
  def _init_rope(self):
 
364
  if self.config.pretraining_tp > 1:
365
  assert (False)
366
  else:
367
+ query_states, key_states, value_states = self.qkv_proj(hidden_states.to(torch.float32))
368
  query_states = query_states.half()
369
  key_states = key_states.half()
370
  value_states = value_states.half()
 
421
  if self.config.pretraining_tp > 1:
422
  assert (False)
423
  else:
424
+ attn_output = self.o_proj(attn_output).half()
425
 
426
  if not output_attentions:
427
  attn_weights = None
 
450
  output_attentions = False
451
 
452
  bsz, q_len, _ = hidden_states.size()
453
+ query_states, key_states, value_states = self.qkv_proj(hidden_states.to(torch.float32))
454
  query_states = query_states.half()
455
  key_states = key_states.half()
456
  value_states = value_states.half()
 
508
  )
509
 
510
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
511
+ attn_output = self.o_proj(attn_output).half()
512
 
513
  if not output_attentions:
514
  attn_weights = None
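
The attention and MLP changes above replace the removed q/k/v/o scale buffers and the hand-written output slicing with a single `FusedQuantizedLinear` whose output is split back into per-projection tensors. The sketch below illustrates only that fuse-and-split pattern: `nn.Linear` stands in for the quantized matmul, the sizes are hypothetical, and the authoritative implementation is `lib/linear/fused_quantized_linear.py`.

```python
# Minimal sketch of the fuse-and-split pattern behind qkv_proj/upgate_proj above.
# nn.Linear is a stand-in for the quantized matmul; fuse_dims mirrors the tuple
# passed to FusedQuantizedLinear in the diff. Not the repo's implementation.
import torch
import torch.nn as nn

class FusedLinearSketch(nn.Module):
    def __init__(self, in_features, fuse_dims):
        super().__init__()
        self.fuse_dims = tuple(fuse_dims)
        # one matmul over the concatenated output dimension
        self.proj = nn.Linear(in_features, sum(fuse_dims), bias=False)

    def forward(self, x):
        # split the fused output back into (q, k, v) or (up, gate) chunks
        return torch.split(self.proj(x), self.fuse_dims, dim=-1)

# GQA-style QKV fusion with hypothetical Llama-70B-like sizes
hidden, n_heads, n_kv_heads, head_dim = 8192, 64, 8, 128
qkv = FusedLinearSketch(hidden, (n_heads * head_dim,
                                 n_kv_heads * head_dim,
                                 n_kv_heads * head_dim))
q, k, v = qkv(torch.randn(2, 5, hidden))
print(q.shape, k.shape, v.shape)  # [2, 5, 8192], [2, 5, 1024], [2, 5, 1024]
```

Dropping the learned per-projection scales changes what a checkpoint contains, which is presumably why MODEL_VERSION is bumped later in this commit.
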
quip-sharp/model/mistral.py CHANGED
@@ -48,6 +48,7 @@ if is_flash_attn_available():
48
  _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
49
 
50
  from lib.linear.quantized_linear import QuantizedLinear
 
51
  from .version import check_model_version
52
 
53
  logger = logging.get_logger(__name__)
@@ -192,15 +193,17 @@ class MistralMLP(nn.Module):
192
  self.hidden_size = config.hidden_size
193
  self.intermediate_size = config.intermediate_size
194
 
195
- self.upgate_proj = QuantizedLinear(self.hidden_size,
196
- self.intermediate_size * 2,
197
- config.quip_params['codesz'],
198
- config.quip_params.get('packsz', 1),
199
- config.quip_params.get('pack_out', False),
200
- config.quip_params['idx_dtype'],
201
- config.quip_params.get('codebook_version', 0),
202
- rank=config.quip_params['lora_rank'],
203
- rescale_WH=config.quip_params['rescale_WH'])
 
 
204
  self.down_proj = QuantizedLinear(
205
  self.config.quip_params['ocs_down_size'] if \
206
  self.config.quip_params['outlier_channel_split'] else self.intermediate_size,
@@ -213,20 +216,11 @@ class MistralMLP(nn.Module):
213
  outlier_channel_split=self.config.quip_params['outlier_channel_split'],
214
  rank=self.config.quip_params['lora_rank'],
215
  rescale_WH=self.config.quip_params['rescale_WH'])
216
- self.register_buffer('up_scale', nn.Parameter(torch.ones(())))
217
- self.register_buffer('gate_scale', nn.Parameter(torch.ones(())))
218
- self.register_buffer('down_scale', nn.Parameter(torch.ones(())))
219
  self.act_fn = ACT2FN[config.hidden_act]
220
 
221
  def forward(self, x):
222
- upgate_proj = self.upgate_proj(x.to(torch.float32))
223
- up_proj = self.up_scale * upgate_proj[...,
224
- 0:self.intermediate_size]
225
- gate_proj = self.gate_scale * upgate_proj[
226
- ..., self.intermediate_size:(self.intermediate_size * 2)]
227
- down_proj = self.down_scale * self.down_proj(
228
- self.act_fn(gate_proj) * up_proj)
229
- return down_proj.half()
230
 
231
 
232
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -264,7 +258,11 @@ class MistralAttention(nn.Module):
264
  f" and `num_heads`: {self.num_heads})."
265
  )
266
 
267
- self.qkv_proj = QuantizedLinear(
268
  self.hidden_size, (self.num_heads * self.head_dim) +
269
  (self.num_key_value_heads * self.head_dim) +
270
  (self.num_key_value_heads * self.head_dim),
@@ -286,11 +284,6 @@ class MistralAttention(nn.Module):
286
  rank=config.quip_params['lora_rank'],
287
  rescale_WH=config.quip_params['rescale_WH'])
288
 
289
- self.register_buffer('q_scale', nn.Parameter(torch.ones(())))
290
- self.register_buffer('k_scale', nn.Parameter(torch.ones(())))
291
- self.register_buffer('v_scale', nn.Parameter(torch.ones(())))
292
- self.register_buffer('o_scale', nn.Parameter(torch.ones(())))
293
-
294
  self.rotary_emb = MistralRotaryEmbedding(
295
  self.head_dim,
296
  max_position_embeddings=self.max_position_embeddings,
@@ -312,19 +305,7 @@ class MistralAttention(nn.Module):
312
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
313
  bsz, q_len, _ = hidden_states.size()
314
 
315
- qkv_states = self.qkv_proj(hidden_states.to(torch.float32))
316
- query_states = self.q_scale * qkv_states[..., 0:(self.num_heads *
317
- self.head_dim)]
318
- key_states = self.k_scale * qkv_states[..., (
319
- self.num_heads * self.head_dim):(
320
- (self.num_heads * self.head_dim) +
321
- (self.num_key_value_heads * self.head_dim))]
322
- value_states = self.v_scale * qkv_states[..., (
323
- (self.num_heads * self.head_dim) +
324
- (self.num_key_value_heads * self.head_dim)):(
325
- (self.num_heads * self.head_dim) +
326
- (self.num_key_value_heads * self.head_dim) +
327
- (self.num_key_value_heads * self.head_dim))]
328
  query_states = query_states.half()
329
  key_states = key_states.half()
330
  value_states = value_states.half()
@@ -379,7 +360,7 @@ class MistralAttention(nn.Module):
379
  attn_output = attn_output.transpose(1, 2).contiguous()
380
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
381
 
382
- attn_output = (self.o_scale * self.o_proj(attn_output)).half()
383
 
384
  if not output_attentions:
385
  attn_weights = None
@@ -406,19 +387,7 @@ class MistralFlashAttention2(MistralAttention):
406
  ):
407
  bsz, q_len, _ = hidden_states.size()
408
 
409
- qkv_states = self.qkv_proj(hidden_states.to(torch.float32))
410
- query_states = self.q_scale * qkv_states[..., 0:(self.num_heads *
411
- self.head_dim)]
412
- key_states = self.k_scale * qkv_states[..., (
413
- self.num_heads * self.head_dim):(
414
- (self.num_heads * self.head_dim) +
415
- (self.num_key_value_heads * self.head_dim))]
416
- value_states = self.v_scale * qkv_states[..., (
417
- (self.num_heads * self.head_dim) +
418
- (self.num_key_value_heads * self.head_dim)):(
419
- (self.num_heads * self.head_dim) +
420
- (self.num_key_value_heads * self.head_dim) +
421
- (self.num_key_value_heads * self.head_dim))]
422
  query_states = query_states.half()
423
  key_states = key_states.half()
424
  value_states = value_states.half()
@@ -517,7 +486,7 @@ class MistralFlashAttention2(MistralAttention):
517
  )
518
 
519
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
520
- attn_output = (self.o_scale * self.o_proj(attn_output)).half()
521
 
522
  if not output_attentions:
523
  attn_weights = None
 
48
  _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
49
 
50
  from lib.linear.quantized_linear import QuantizedLinear
51
+ from lib.linear.fused_quantized_linear import FusedQuantizedLinear
52
  from .version import check_model_version
53
 
54
  logger = logging.get_logger(__name__)
 
193
  self.hidden_size = config.hidden_size
194
  self.intermediate_size = config.intermediate_size
195
 
196
+ self.upgate_proj = FusedQuantizedLinear(
197
+ -1, (self.intermediate_size, self.intermediate_size),
198
+ self.hidden_size,
199
+ self.intermediate_size * 2,
200
+ config.quip_params['codesz'],
201
+ config.quip_params.get('packsz', 1),
202
+ config.quip_params.get('pack_out', False),
203
+ config.quip_params['idx_dtype'],
204
+ config.quip_params.get('codebook_version', 0),
205
+ rank=config.quip_params['lora_rank'],
206
+ rescale_WH=config.quip_params['rescale_WH'])
207
  self.down_proj = QuantizedLinear(
208
  self.config.quip_params['ocs_down_size'] if \
209
  self.config.quip_params['outlier_channel_split'] else self.intermediate_size,
 
216
  outlier_channel_split=self.config.quip_params['outlier_channel_split'],
217
  rank=self.config.quip_params['lora_rank'],
218
  rescale_WH=self.config.quip_params['rescale_WH'])
 
 
 
219
  self.act_fn = ACT2FN[config.hidden_act]
220
 
221
  def forward(self, x):
222
+ up_proj, gate_proj = self.upgate_proj(x.to(torch.float32))
223
+ return self.down_proj(self.act_fn(gate_proj) * up_proj).half()
224
 
225
 
226
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
258
  f" and `num_heads`: {self.num_heads})."
259
  )
260
 
261
+ self.qkv_proj = FusedQuantizedLinear(
262
+ -1,
263
+ (self.num_heads*self.head_dim,
264
+ self.num_key_value_heads*self.head_dim,
265
+ self.num_key_value_heads*self.head_dim),
266
  self.hidden_size, (self.num_heads * self.head_dim) +
267
  (self.num_key_value_heads * self.head_dim) +
268
  (self.num_key_value_heads * self.head_dim),
 
284
  rank=config.quip_params['lora_rank'],
285
  rescale_WH=config.quip_params['rescale_WH'])
286
287
  self.rotary_emb = MistralRotaryEmbedding(
288
  self.head_dim,
289
  max_position_embeddings=self.max_position_embeddings,
 
305
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
306
  bsz, q_len, _ = hidden_states.size()
307
 
308
+ query_states, key_states, value_states = self.qkv_proj(hidden_states.to(torch.float32))
309
  query_states = query_states.half()
310
  key_states = key_states.half()
311
  value_states = value_states.half()
 
360
  attn_output = attn_output.transpose(1, 2).contiguous()
361
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
362
 
363
+ attn_output = self.o_proj(attn_output).half()
364
 
365
  if not output_attentions:
366
  attn_weights = None
 
387
  ):
388
  bsz, q_len, _ = hidden_states.size()
389
 
390
+ query_states, key_states, value_states = self.qkv_proj(hidden_states.to(torch.float32))
391
  query_states = query_states.half()
392
  key_states = key_states.half()
393
  value_states = value_states.half()
 
486
  )
487
 
488
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
489
+ attn_output = self.o_proj(attn_output).half()
490
 
491
  if not output_attentions:
492
  attn_weights = None
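
The MistralMLP change is the same fusion applied to the feed-forward block: one quantized matmul produces up and gate, and the forward pass reduces to the two `+` lines above. A small self-contained sketch of that SwiGLU forward, with plain linears standing in for the quantized layers and SiLU assumed as the activation behind `ACT2FN[config.hidden_act]`:

```python
# SwiGLU MLP with a fused up/gate projection, mirroring the added forward().
# nn.Linear replaces FusedQuantizedLinear/QuantizedLinear; sizes are hypothetical.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUMLPSketch(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.intermediate_size = intermediate_size
        self.upgate_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        up, gate = torch.split(self.upgate_proj(x),
                               (self.intermediate_size, self.intermediate_size),
                               dim=-1)
        return self.down_proj(F.silu(gate) * up)

mlp = SwiGLUMLPSketch(hidden_size=4096, intermediate_size=14336)
print(mlp(torch.randn(2, 5, 4096)).shape)  # torch.Size([2, 5, 4096])
```
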
quip-sharp/model/version.py CHANGED
@@ -1,8 +1,8 @@
1
- MODEL_VERSION = 0
2
 
3
  def check_model_version(test):
4
  if test != MODEL_VERSION:
5
  raise Exception(
6
  f"Saved model version ({test}) does not match the "\
7
  f"source code model version ({MODEL_VERSION}). "\
8
- "Please pull the latest code from git@github.com:Cornell-RelaxML/quip-sharp.git")
 
1
+ MODEL_VERSION = 1
2
 
3
  def check_model_version(test):
4
  if test != MODEL_VERSION:
5
  raise Exception(
6
  f"Saved model version ({test}) does not match the "\
7
  f"source code model version ({MODEL_VERSION}). "\
8
+ "Please pull the latest code or model checkpoints.")
quip-sharp/quantize_llama.py CHANGED
@@ -26,8 +26,8 @@ parser.add_argument('--num_cpu_threads', default=8, type=int)
26
  parser.add_argument('--batch_size', default=8, type=int)
27
  parser.add_argument('--devset_size', default=64, type=int)
28
  parser.add_argument('--ctx_size', default=2048, type=int)
29
- parser.add_argument('--save_path', default='checkpoints/quantized-hada-70b', type=str)
30
- parser.add_argument('--hessian_path', default='/share/desa/nfs01/quip_llama2/hessians', type=str)
31
  parser.add_argument('--base_model', default='meta-llama/Llama-2-70b-hf', type=str)
32
  parser.add_argument('--sigma_reg', default=1e-2, type=float)
33
  parser.add_argument('--sigma_reg2', default=1e-2, type=float)
@@ -286,7 +286,7 @@ def main(args):
286
  all_config['model_config'].quip_params['ocs_down_size'] = args.ocs_down_size
287
  torch.save(all_config, os.path.join(args.save_path, 'config.pt'))
288
 
289
- tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=False)
290
  tokenizer.pad_token = tokenizer.eos_token
291
  glog.info('loaded model')
292
 
 
26
  parser.add_argument('--batch_size', default=8, type=int)
27
  parser.add_argument('--devset_size', default=64, type=int)
28
  parser.add_argument('--ctx_size', default=2048, type=int)
29
+ parser.add_argument('--save_path', type=str)
30
+ parser.add_argument('--hessian_path', type=str)
31
  parser.add_argument('--base_model', default='meta-llama/Llama-2-70b-hf', type=str)
32
  parser.add_argument('--sigma_reg', default=1e-2, type=float)
33
  parser.add_argument('--sigma_reg2', default=1e-2, type=float)
 
286
  all_config['model_config'].quip_params['ocs_down_size'] = args.ocs_down_size
287
  torch.save(all_config, os.path.join(args.save_path, 'config.pt'))
288
 
289
+ tokenizer = AutoTokenizer.from_pretrained(args.base_model)
290
  tokenizer.pad_token = tokenizer.eos_token
291
  glog.info('loaded model')
292
 
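
The quantize_llama.py changes drop save and Hessian paths that only existed on the authors' cluster and let the tokenizer default to the fast implementation. With no defaults, both path flags must now be supplied; a minimal sketch of the resulting argparse behaviour (the paths below are placeholders):

```python
# Sketch of the argparse change: --save_path and --hessian_path no longer have
# machine-specific defaults, so they must be passed explicitly.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--save_path', type=str)
parser.add_argument('--hessian_path', type=str)

args = parser.parse_args(['--save_path', 'ckpt/llama2-70b-e8p-2bit',
                          '--hessian_path', 'hessians/llama2-70b'])
print(args.save_path, args.hessian_path)
# Omitting the flags now yields None rather than a path on someone else's NFS share,
# so a typical run looks like:
#   python quantize_llama.py --save_path <out_dir> --hessian_path <hessian_dir> \
#       --base_model meta-llama/Llama-2-70b-hf
```
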
quip-sharp/quiptools/build/lib.linux-x86_64-cpython-310/quiptools_cuda.cpython-310-x86_64-linux-gnu.so CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:899ab752e90b0cf2fbb9548cf4017ad053a8dda89b5cc78a765b72b3703bc11b
3
- size 13026208
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b869e88b6457109857b32ebb2ba424ebb6cdc53ecc3f856ef881709a4fbccf85
3
+ size 12982144
quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/.ninja_deps CHANGED
Binary files a/quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/.ninja_deps and b/quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/.ninja_deps differ
 
quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/.ninja_log CHANGED
@@ -1,6 +1,4 @@
1
  # ninja log v5
2
- 0 18168 1703587582805727136 /run/media/knut/HD/text-generation-webui/repositories/quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools_wrapper.o 1b1606004175d38f
3
- 9 19979 1703587632810492590 /run/media/knut/HD/text-generation-webui/repositories/quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools_wrapper.o c55be518cf9b4c1e
4
- 8 18153 1703587706965942532 /run/media/knut/HD/text-generation-webui/repositories/quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools_wrapper.o 1b1606004175d38f
5
- 8 43545 1703587732366665742 /run/media/knut/HD/text-generation-webui/repositories/quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools.o f601b9f154f8bde0
6
- 8 47187 1703587736006769314 /run/media/knut/HD/text-generation-webui/repositories/quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools_e8p_gemv.o 9d441ce55de572ae
 
1
  # ninja log v5
2
+ 0 19060 1704135981196240899 /run/media/knut/HD/text-generation-webui/repositories/quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools_wrapper.o 1b1606004175d38f
3
+ 0 43897 1704136006046927545 /run/media/knut/HD/text-generation-webui/repositories/quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools.o f601b9f154f8bde0
4
+ 0 46532 1704136008677000213 /run/media/knut/HD/text-generation-webui/repositories/quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools_e8p_gemv.o 9d441ce55de572ae
 
 
quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools.o CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:57c7f4e07515bed1baa410f842db9703d80c82b25dd0d69588266772eebe746c
3
  size 2174288
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:874d7a979e49a3dcecaec7c56b4072f8f055e60a4560f5225ad1c385ce38607c
3
  size 2174288
quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools_e8p_gemv.o CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:822a1718cc8158b9b8311cdc714853d8adc22ce3ddaa5af7a85a8839fe3757a6
3
- size 5510384
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f237dcab2ca8896deac1088c1857f78235ba149f26b26b55a3761832fe08c1f6
3
+ size 5448600
quip-sharp/quiptools/build/temp.linux-x86_64-cpython-310/quiptools_wrapper.o CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e9b169bc14cbe82bc31c5c9759c8cde3c6ee185919c42ef251d99dd3b8ca06ef
3
- size 6681584
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7aa0a52748df5ca8b4e93a471bf7f3b491676c8c4a2be257855f08b1fa6c7f3
3
+ size 6729784
quip-sharp/quiptools/dist/quiptools_cuda-0.0.0-py3.10-linux-x86_64.egg CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8393c6c6b07efe01541da956828ceca44efb79465dda3ea62b1e25b5c84297ed
3
- size 4193234
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52d0c3eab47c2b08bed9da2d4a6e26c6ebb373192077b27eb4162f01d82a4237
3
+ size 4181938
quip-sharp/quiptools/quiptools_cuda.egg-info/SOURCES.txt CHANGED
@@ -1,12 +1,7 @@
1
- .quiptools.cu.swp
2
- benchmark_e8p.py
3
- error.txt
4
  quiptools.cu
5
  quiptools_e8p_gemv.cu
6
  quiptools_wrapper.cpp
7
  setup.py
8
- test_d4.py
9
- test_e8p.py
10
  quiptools_cuda.egg-info/PKG-INFO
11
  quiptools_cuda.egg-info/SOURCES.txt
12
  quiptools_cuda.egg-info/dependency_links.txt
1
  quiptools.cu
2
  quiptools_e8p_gemv.cu
3
  quiptools_wrapper.cpp
4
  setup.py
 
 
5
  quiptools_cuda.egg-info/PKG-INFO
6
  quiptools_cuda.egg-info/SOURCES.txt
7
  quiptools_cuda.egg-info/dependency_links.txt
quip-sharp/quiptools/quiptools_e8p_gemv.cu CHANGED
@@ -9,6 +9,8 @@
9
  #include <cuda_fp16.h>
10
  #include <mma.h>
11
 
 
 
12
  #include <ATen/ATen.h>
13
  #include <ATen/Context.h>
14
  #include <ATen/Dispatch.h>
@@ -40,222 +42,235 @@ __host__ static inline void gpuAssert(cudaError_t code, const char *file, int li
40
  }
41
  }
42
 
43
- #define BLOCK_SIZE 512
44
- #define WARP_SIZE 32
45
-
 
 
46
 
47
- __device__ static inline uint64_t decode8weights(
48
- uint16_t weight_compressed,
49
- const int64_t *__restrict__ codebook_abs
50
- ) {
51
 
52
- uint32_t bit_shift = (weight_compressed & 1)^1;
53
- uint8_t bits_sign = (weight_compressed >> 1) & ((1 << 7) - 1);
54
- uint8_t bits_abs = (weight_compressed >> 8) & ((1 << 9) - 1);
55
-
56
- int64_t packed_ = codebook_abs[bits_abs];
57
- uint32_t packed[2];
58
- memcpy(packed, &packed_, sizeof(packed));
59
-
60
- // TODO: optimize this by redefining the bit pattern
61
- uint32_t parity = __popc(packed[0] & 0x04040404) ^ __popc(packed[1]&0x04040404);
62
- uint8_t sign_vec = bits_sign | ((__popc(bits_sign) ^ parity) << 7);
63
- uint32_t decoded_sign[2];
64
- decoded_sign[0] = sign_vec * 0x08040201ll;
65
- decoded_sign[1] = sign_vec * 0x80402010ll;
66
- decoded_sign[0] &= 0x80808080;
67
- decoded_sign[1] &= 0x80808080;
68
- decoded_sign[0] >>= 7;
69
- decoded_sign[1] >>= 7;
70
- decoded_sign[0] *= 255 - 3;
71
- decoded_sign[1] *= 255 - 3;
72
- packed[0] ^= decoded_sign[0];
73
- packed[1] ^= decoded_sign[1];
74
- packed[0] |= 0x01010101;
75
- packed[1] |= 0x01010101;
76
- packed[0] -= bit_shift * 0x02020202;
77
- packed[1] -= bit_shift * 0x02020202;
78
-
79
- memcpy(&packed_, packed, sizeof(packed));
80
-
81
- return packed_;
82
  }
83
 
 
 
 
 
84
 
85
- /*
86
- llama 2 70B:
87
- M N K
88
- 1 8192 8192
89
- 1 57344 8192
90
- 1 8192 28672
91
- 1 10240 8192
92
- */
93
- template <typename scalar_t>
94
  __global__ static void
95
- __launch_bounds__(BLOCK_SIZE)
96
- decode_matmul_e8p_kernel(
97
- scalar_t *__restrict__ output,
98
- const scalar_t *__restrict__ x,
99
- const int16_t *__restrict__ weights_compressed,
100
- const int64_t *__restrict__ codebook_abs,
101
- int64_t M,
102
- int64_t N,
103
- int64_t K
104
  ) {
105
- __shared__ int64_t codebook_local[256];
106
- if (threadIdx.x < 256) {
107
- codebook_local[threadIdx.x] = codebook_abs[threadIdx.x];
108
- }
109
- __syncthreads();
110
-
111
- int64_t warpId = threadIdx.x / WARP_SIZE;
112
- int64_t laneId = threadIdx.x % WARP_SIZE;
113
-
114
- // each thread adds 8 activation-weight products
115
- const int64_t unroll_k = 2;
116
- const int64_t pack = 8;
117
- const int64_t elem_per_thread = pack * unroll_k;
118
- int64_t warps_per_elem = K / WARP_SIZE / elem_per_thread;
119
- const int64_t unroll_n = 16;
120
- const int64_t local_k = 1; // in terms of warp size. 32 threads of elem_per_thread fma each, dont set below 1 because of __shfl_down_sync
121
- int64_t local_n = BLOCK_SIZE / WARP_SIZE / local_k;
122
- int64_t grid_N = N / unroll_n;
123
-
124
- __shared__ scalar_t accum_scratch[BLOCK_SIZE / WARP_SIZE];
125
- bool SHARED_REDUCE = false;
126
-
127
- for (int64_t warpPos = blockIdx.x * BLOCK_SIZE/WARP_SIZE + warpId;
128
- warpPos < M * grid_N * warps_per_elem;
129
- warpPos += gridDim.x * BLOCK_SIZE/WARP_SIZE) {
130
-
131
- int64_t local_n_i = (warpPos% (BLOCK_SIZE / WARP_SIZE)) / local_k;
132
- int64_t local_k_i = (warpPos% (BLOCK_SIZE / WARP_SIZE)) % local_k;
133
- int64_t m = (warpPos / warps_per_elem) / (grid_N);
134
- int64_t k_ = warpPos % (warps_per_elem * local_n);
135
- int64_t k = k_ / (local_k * local_n) * local_k + k_ % local_k;
136
-
137
- scalar_t this_activations[elem_per_thread];
138
- #pragma unroll
139
- for (int64_t unroll_k_i = 0; unroll_k_i < unroll_k; unroll_k_i++) {
140
- const scalar_t *activations = x + m * K + (k * WARP_SIZE + laneId) * elem_per_thread + unroll_k_i * pack;
141
- if constexpr (std::is_same<scalar_t, float>::value) {
142
- const float4 *first_half = reinterpret_cast<const float4 *>(activations);
143
- __builtin_assume_aligned(first_half, 16);
144
- this_activations[unroll_k_i * pack + 0] = first_half->x;
145
- this_activations[unroll_k_i * pack + 1] = first_half->y;
146
- this_activations[unroll_k_i * pack + 2] = first_half->z;
147
- this_activations[unroll_k_i * pack + 3] = first_half->w;
148
- const float4 *second_half = reinterpret_cast<const float4 *>(activations + 4);
149
- __builtin_assume_aligned(second_half, 16);
150
- this_activations[unroll_k_i * pack + 4] = second_half->x;
151
- this_activations[unroll_k_i * pack + 5] = second_half->y;
152
- this_activations[unroll_k_i * pack + 6] = second_half->z;
153
- this_activations[unroll_k_i * pack + 7] = second_half->w;
154
- } else {
155
- for (int64_t activation_i = 0; activation_i < pack; activation_i++) {
156
- this_activations[unroll_k_i * pack + activation_i] = activations[activation_i];
157
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  }
159
- }
160
- for (int64_t unroll_n_i = 0; unroll_n_i < unroll_n; unroll_n_i++) {
161
- scalar_t accumulator = 0;
162
- int64_t n = ((warpPos/local_k) % local_n) + ((warpPos / warps_per_elem) % grid_N) / local_n * local_n;
163
- __syncwarp();
164
- uint16_t this_weights[unroll_k];
165
- if (unroll_k % 2 == 0) {
166
- for (int64_t unroll_k_i = 0; unroll_k_i < unroll_k; unroll_k_i+=2) {
167
- const ushort2 *loaded = (const ushort2 *) &weights_compressed[(n*unroll_n + unroll_n_i) * K/pack + (k * WARP_SIZE + laneId) * unroll_k + unroll_k_i];
168
- __builtin_assume_aligned(loaded, 4);
169
- this_weights[unroll_k_i] = loaded->x;
170
- this_weights[unroll_k_i + 1] = loaded->y;
171
- }
172
- } else {
173
- for (int64_t unroll_k_i = 0; unroll_k_i < unroll_k; unroll_k_i++) {
174
- this_weights[unroll_k_i] = weights_compressed[(n*unroll_n + unroll_n_i) * K/pack + (k * WARP_SIZE + laneId) * unroll_k + unroll_k_i];
175
- }
176
- }
177
-
178
- #pragma unroll
179
- for (int64_t unroll_k_i = 0; unroll_k_i < unroll_k; unroll_k_i++) {
180
- // TODO: optimize access pattern by reordering weights
181
- uint16_t encoded = this_weights[unroll_k_i];
182
- uint64_t decoded = decode8weights(encoded, codebook_local);
183
-
184
- #ifdef EMULATED_INT82FP16
185
- // bit twiddling to convert int8 to fp16 from http://arxiv.org/abs/2211.10017
186
- half2 unpacked[2][2];
187
- uint64_t lower_half = decoded & 0x00ff00ff00ff00ff;
188
- lower_half = (lower_half ^ 0x6480648064806480);
189
- memcpy(unpacked[0], &lower_half, sizeof(uint64_t));
190
- uint64_t upper_half = (decoded & 0xff00ff00ff00ff00) >> 8;
191
- upper_half = (upper_half ^ 0x6480648064806480);
192
- memcpy(unpacked[1], &upper_half, sizeof(uint64_t));
193
-
194
- const half2 adjust = {__float2half(-1152.0f), __float2half(-1152.0f)};
195
- unpacked[0][0] = __hadd2(unpacked[0][0], adjust);
196
- unpacked[0][1] = __hadd2(unpacked[0][1], adjust);
197
- unpacked[1][0] = __hadd2(unpacked[1][0], adjust);
198
- unpacked[1][1] = __hadd2(unpacked[1][1], adjust);
199
-
200
- float2 unpacked_f[2][2];
201
- unpacked_f[0][0] = __half22float2(unpacked[0][0]);
202
- unpacked_f[0][1] = __half22float2(unpacked[0][1]);
203
- unpacked_f[1][0] = __half22float2(unpacked[1][0]);
204
- unpacked_f[1][1] = __half22float2(unpacked[1][1]);
205
-
206
-
207
- accumulator += this_activations[unroll_k_i * pack + 0] * (unpacked_f[0][0].x);
208
- accumulator += this_activations[unroll_k_i * pack + 1] * (unpacked_f[1][0].x);
209
- accumulator += this_activations[unroll_k_i * pack + 2] * (unpacked_f[0][0].y);
210
- accumulator += this_activations[unroll_k_i * pack + 3] * (unpacked_f[1][0].y);
211
- accumulator += this_activations[unroll_k_i * pack + 4] * (unpacked_f[0][1].x);
212
- accumulator += this_activations[unroll_k_i * pack + 5] * (unpacked_f[1][1].x);
213
- accumulator += this_activations[unroll_k_i * pack + 6] * (unpacked_f[0][1].y);
214
- accumulator += this_activations[unroll_k_i * pack + 7] * (unpacked_f[1][1].y);
215
- #else
216
- for (int64_t i = 0; i < 8; i += 1) {
217
- int8_t weight = decoded >> (i * 8);
218
- accumulator += this_activations[unroll_k_i * pack + i] * (int8_t) weight;
219
- }
220
- #endif
221
- }
222
- accumulator *= 0.25;
223
-
224
- for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
225
- // apparently c10::Half does arithmetic operations in float32?
226
- // https://github.com/pytorch/pytorch/blob/0bd4d1f4ab38d3088de8aa5fbba35427b42d118e/c10/util/Half.h#L4C58-L6C80
227
- if constexpr (std::is_same<scalar_t, c10::Half>::value) {
228
- accumulator += __shfl_down_sync(0xFFFFFFFF, __float2half(accumulator), offset);
229
- } else {
230
- accumulator += __shfl_down_sync(0xFFFFFFFF, accumulator, offset);
231
- }
232
  }
 
233
 
234
- if (SHARED_REDUCE) {
235
- if (laneId == 0) {
236
- accum_scratch[warpId] = accumulator;
237
- __syncthreads();
238
- if (warpId % local_k == 0) {
239
- scalar_t local_accum = 0;
240
- for (int64_t accum_i = 0; accum_i < local_k; accum_i++) {
241
- local_accum += accum_scratch[warpId / local_k * local_k + accum_i];
242
- }
243
- atomicAdd(output + m * N + n * unroll_n + unroll_n_i, local_accum);
244
- }
245
- } else {
246
- __syncthreads();
247
- }
248
- } else {
249
- if (laneId == 0) {
250
- atomicAdd(output + m * N + n * unroll_n + unroll_n_i, accumulator);
251
- }
252
- }
253
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  }
255
  }
256
 
257
 
258
- __host__ extern torch::Tensor decode_matmul_e8p(
259
  torch::Tensor x,
260
  torch::Tensor weights_compressed,
261
  torch::Tensor codebook_abs
@@ -265,47 +280,306 @@ __host__ extern torch::Tensor decode_matmul_e8p(
265
  CHECK_INPUT(weights_compressed);
266
  CHECK_INPUT(codebook_abs);
267
 
268
- TORCH_CHECK(weights_compressed.scalar_type() == torch::kInt16);
269
- TORCH_CHECK(codebook_abs.scalar_type() == torch::kInt64);
270
- TORCH_CHECK(x.size(-1) == weights_compressed.size(-1) << 3);
 
 
 
 
 
 
271
  TORCH_CHECK(codebook_abs.size(-1) == 256);
272
 
273
- int64_t M = x.size(-2);
274
- int64_t N = weights_compressed.size(-2);
275
  int64_t K = x.size(-1);
276
- //printf("%lld %lld %lld\n", M, N, K);
277
 
278
- TORCH_CHECK(K % WARP_SIZE == 0, "K is not divisible by WARP_SIZE");
 
 
 
 
279
 
280
  at::DeviceGuard guard(x.device());
281
  torch::TensorOptions options = torch::TensorOptions()
282
- .dtype(x.scalar_type())
283
  .layout(torch::kStrided)
284
  .device(torch::kCUDA)
285
  .requires_grad(false);
286
- torch::Tensor output = torch::zeros(std::vector<int64_t>{M, N}, options);
287
 
288
  cudaDeviceProp deviceProp;
289
  cudaGetDeviceProperties(&deviceProp, x.get_device());
290
- int64_t grid_size = static_cast<int64_t>(6 * deviceProp.multiProcessorCount);
291
  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
292
 
293
- AT_DISPATCH_FLOATING_TYPES_AND2(
294
- at::ScalarType::Half,
295
- at::ScalarType::BFloat16,
296
- x.scalar_type(),
297
- "decode_matmul_e8p",
298
- [&] {
299
- decode_matmul_e8p_kernel<<<grid_size, BLOCK_SIZE, 0, stream>>>(
300
- output.data_ptr<scalar_t>(),
301
- x.data_ptr<scalar_t>(),
302
- weights_compressed.data_ptr<int16_t>(),
303
- codebook_abs.data_ptr<int64_t>(),
304
- M,
305
- N,
306
- K);
307
- gpuErrchk(cudaPeekAtLastError());
308
- });
309
 
310
  return output;
311
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  #include <cuda_fp16.h>
10
  #include <mma.h>
11
 
12
+ #include <cuda_pipeline.h>
13
+
14
  #include <ATen/ATen.h>
15
  #include <ATen/Context.h>
16
  #include <ATen/Dispatch.h>
 
42
  }
43
  }
44
 
45
+ __device__ static inline uint32_t add_as_half2(uint32_t x, uint32_t y) {
46
+ uint32_t z;
47
+ asm("add.f16x2 %0,%1,%2;" : "=r"(z) : "r"(x), "r"(y));
48
+ return z;
49
+ }
50
 
 
 
 
 
51
 
52
+ __device__ static inline uint32_t mask_lop3(uint32_t x, uint32_t m0, uint32_t m1) {
53
+ uint32_t y;
54
+ asm("lop3.b32 %0, %1, %2, %3, 0xEA;" : "=r"(y) : "r"(x), "r"(m0), "r"(m1));
55
+ return y;
56
+ // return (x & m0) | m1;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  }
58
 
59
+ #define BASE_OFFSET 0xd080d080
60
+ #define XMASK 0x00f000f0
61
+ #define WMASK 0x50085008
62
+
63
 
 
 
 
 
 
 
 
 
 
64
  __global__ static void
65
+ // __launch_bounds__(1024, 1024)
66
+ decode_matvec_e8p_kernel(
67
+ float *__restrict__ output,
68
+ const uint2 *__restrict__ input,
69
+ const uint2 *__restrict__ weights_compressed,
70
+ const uint32_t *__restrict__ codebook_abs,
71
+ int N,
72
+ int K
 
73
  ) {
74
+ int warpId = threadIdx.y;
75
+ int laneId = threadIdx.x;
76
+
77
+ // __shared__ float sum_scratch[16*32];
78
+
79
+ // __shared__ uint32_t codebook_local[256*32];
80
+ // for (int icb = warpId; icb < 256; icb += 32) {
81
+ // codebook_local[icb*32 + laneId] = codebook_abs[icb];
82
+ // }
83
+ // __syncthreads();
84
+
85
+ __shared__ uint2 shared_weights[1024*2];
86
+
87
+ for (int iin = blockIdx.x; iin < (N >> 4); iin += gridDim.x) {
88
+
89
+ float z0 = 0.0;
90
+ float z1 = 0.0;
91
+ float z2 = 0.0;
92
+ float z3 = 0.0;
93
+
94
+ // int shwo = laneId + 32*warpId;
95
+
96
+ // __pipeline_memcpy_async(shared_weights + shwo, weights_compressed + laneId + 32*warpId + 1024*0 + (K >> 1)*iin, 8);
97
+ // __pipeline_commit();
98
+
99
+ for (int iik = warpId; iik < (K >> 6); iik += 32) {
100
+ // if (iik + 1 < (K >> 11)) {
101
+ // __pipeline_memcpy_async(shared_weights + (shwo ^ 1024), weights_compressed + laneId + 32*iik + 1024 + (K >> 1)*iin, 8);
102
+ // __pipeline_commit();
103
+ // __pipeline_wait_prior(1);
104
+ // shwo = shwo ^ 1024;
105
+ // }
106
+ // else {
107
+ // __pipeline_wait_prior(0);
108
+ // }
109
+
110
+ // uint2 w_compr = shared_weights[shwo]; // weights_compressed[laneId + 32*warpId + 1024*iik + (K >> 1)*iin];
111
+ uint2 w_compr = weights_compressed[laneId + 32*iik + (K >> 1)*iin];
112
+ uint32_t a = w_compr.x;
113
+ uint32_t b = w_compr.y;
114
+
115
+ uint32_t s = b;
116
+ s = s ^ (s >> 4);
117
+ s = s ^ (s >> 8);
118
+ s = s ^ (s >> 16);
119
+ uint32_t sb = (s & 15);
120
+ s = b ^ sb;
121
+ sb = sb | (sb << 16);
122
+
123
+ uint32_t input_to_warp = ((const uint32_t*)(&input[16*iik]))[laneId];
124
+ uint32_t shifted_laneId = (laneId & 3) << 3;
125
+
126
+ /// BLOCK 01
127
+ {
128
+ uint32_t x = codebook_abs[(a >> 0) & 255];
129
+ x = x ^ ((s & 0x11111111) * 14);
130
+
131
+ uint32_t o = BASE_OFFSET | ((sb & 0x00010001) << 4);
132
+
133
+ uint32_t w00 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o);
134
+ uint32_t w01 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o);
135
+ uint32_t w02 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o);
136
+ uint32_t w03 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o);
137
+
138
+ x = codebook_abs[(a >> 8) & 255];
139
+ x = x ^ ((s & 0x22222222) * 7);
140
+
141
+ o = BASE_OFFSET | ((sb & 0x00020002) << 3);
142
+
143
+ uint32_t w10 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o);
144
+ uint32_t w11 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o);
145
+ uint32_t w12 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o);
146
+ uint32_t w13 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o);
147
+
148
+ // uint2 x_in = input[0 + (laneId & 3)*4 + 16*warpId + 16*32*iik];
149
+ // uint32_t x_in0 = x_in.x;
150
+ // uint32_t x_in1 = x_in.y;
151
+
152
+ uint32_t x_in0 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 0);
153
+ uint32_t x_in1 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 1);
154
+
155
+ asm(
156
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
157
+ " { %0, %1, %2, %3 },"
158
+ " { %4, %5, %6, %7 },"
159
+ " { %8, %9 },"
160
+ " { %0, %1, %2, %3 };"
161
+ : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3)
162
+ : "r"(w00), "r"(w10), "r"(w01), "r"(w11),
163
+ "r"(x_in0), "r"(x_in1)
164
+ );
165
+
166
+
167
+ // x_in = input[1 + (laneId & 3)*4 + 16*warpId + 16*32*iik];
168
+ // x_in0 = x_in.x;
169
+ // x_in1 = x_in.y;
170
+
171
+ x_in0 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 2);
172
+ x_in1 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 3);
173
+
174
+ asm(
175
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
176
+ " { %0, %1, %2, %3 },"
177
+ " { %4, %5, %6, %7 },"
178
+ " { %8, %9 },"
179
+ " { %0, %1, %2, %3 };"
180
+ : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3)
181
+ : "r"(w02), "r"(w12), "r"(w03), "r"(w13),
182
+ "r"(x_in0), "r"(x_in1)
183
+ );
184
  }
185
+ /// BLOCK 23
186
+ {
187
+ uint32_t x = codebook_abs[(a >> 16) & 255];
188
+ s = s >> 2;
189
+ x = x ^ ((s & 0x11111111) * 14);
190
+
191
+ uint32_t o = BASE_OFFSET | ((sb & 0x00040004) << 2);
192
+
193
+ uint32_t w00 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o);
194
+ uint32_t w01 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o);
195
+ uint32_t w02 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o);
196
+ uint32_t w03 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o);
197
+
198
+ x = codebook_abs[(a >> 24) & 255];
199
+ x = x ^ ((s & 0x22222222) * 7);
200
+
201
+ o = BASE_OFFSET | ((sb & 0x00080008) << 1);
202
+
203
+ uint32_t w10 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o);
204
+ uint32_t w11 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o);
205
+ uint32_t w12 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o);
206
+ uint32_t w13 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o);
207
+
208
+
209
+ // uint2 x_in = input[2 + (laneId & 3)*4 + 16*warpId + 16*32*iik];
210
+ // uint32_t x_in0 = x_in.x;
211
+ // uint32_t x_in1 = x_in.y;
212
+
213
+ uint32_t x_in0 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 4);
214
+ uint32_t x_in1 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 5);
215
+
216
+ asm(
217
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
218
+ " { %0, %1, %2, %3 },"
219
+ " { %4, %5, %6, %7 },"
220
+ " { %8, %9 },"
221
+ " { %0, %1, %2, %3 };"
222
+ : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3)
223
+ : "r"(w00), "r"(w10), "r"(w01), "r"(w11),
224
+ "r"(x_in0), "r"(x_in1)
225
+ );
226
+
227
+
228
+ // x_in = input[3 + (laneId & 3)*4 + 16*warpId + 16*32*iik];
229
+ // x_in0 = x_in.x;
230
+ // x_in1 = x_in.y;
231
+
232
+ x_in0 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 6);
233
+ x_in1 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 7);
234
+
235
+ asm(
236
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
237
+ " { %0, %1, %2, %3 },"
238
+ " { %4, %5, %6, %7 },"
239
+ " { %8, %9 },"
240
+ " { %0, %1, %2, %3 };"
241
+ : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3)
242
+ : "r"(w02), "r"(w12), "r"(w03), "r"(w13),
243
+ "r"(x_in0), "r"(x_in1)
244
+ );
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  }
246
+ }
247
 
248
+ // we produced 16 outputs, so only the 16 even lanes write them
249
+ if ((laneId & 1) == 0) {
250
+ atomicAdd(output + (iin << 4) + (laneId >> 1), (laneId & 2) ? z2 : z0);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  }
252
+
253
+ // if ((laneId & 3) == 0) {
254
+ // sum_scratch[warpId + ((laneId >> 1) + 0) * 32] = z0;
255
+ // sum_scratch[warpId + ((laneId >> 1) + 1) * 32] = z2;
256
+ // }
257
+ // __syncthreads();
258
+
259
+ // // load and sum
260
+ // if (warpId < 16) {
261
+ // float acc = sum_scratch[laneId + warpId*32];
262
+ // for (int offset = 16; offset > 0; offset /= 2) {
263
+ // acc += __shfl_down_sync(FULL_MASK, acc, offset);
264
+ // }
265
+ // if (laneId == 0) {
266
+ // output[(iin << 4) + warpId] = acc;
267
+ // }
268
+ // }
269
  }
270
  }
271
 
272
 
273
+ __host__ extern torch::Tensor decode_matvec_e8p(
274
  torch::Tensor x,
275
  torch::Tensor weights_compressed,
276
  torch::Tensor codebook_abs
 
280
  CHECK_INPUT(weights_compressed);
281
  CHECK_INPUT(codebook_abs);
282
 
283
+ TORCH_CHECK(x.dim() == 1);
284
+ TORCH_CHECK(weights_compressed.dim() == 4);
285
+ TORCH_CHECK(weights_compressed.size(3) == 4);
286
+ TORCH_CHECK(weights_compressed.size(2) == 8);
287
+ TORCH_CHECK(codebook_abs.dim() == 1);
288
+ TORCH_CHECK(x.scalar_type() == torch::kFloat16);
289
+ TORCH_CHECK(weights_compressed.scalar_type() == torch::kInt64);
290
+ TORCH_CHECK(codebook_abs.scalar_type() == torch::kInt32);
291
+ TORCH_CHECK(x.size(-1) == weights_compressed.size(1) << 6);
292
  TORCH_CHECK(codebook_abs.size(-1) == 256);
293
 
294
+ int64_t N = weights_compressed.size(0) * 16;
 
295
  int64_t K = x.size(-1);
 
296
 
297
+ TORCH_CHECK(K % 64 == 0, "K is not divisible by 64");
298
+ TORCH_CHECK(N % 16 == 0, "N is not divisible by 16");
299
+
300
+ TORCH_CHECK(K < 65536, "K is too large");
301
+ TORCH_CHECK(N < 65536, "N is too large");
302
 
303
  at::DeviceGuard guard(x.device());
304
  torch::TensorOptions options = torch::TensorOptions()
305
+ .dtype(torch::kFloat32)
306
  .layout(torch::kStrided)
307
  .device(torch::kCUDA)
308
  .requires_grad(false);
309
+ torch::Tensor output = torch::zeros(std::vector<int64_t>{N}, options);
310
 
311
  cudaDeviceProp deviceProp;
312
  cudaGetDeviceProperties(&deviceProp, x.get_device());
313
+ int64_t grid_size = static_cast<int64_t>(deviceProp.multiProcessorCount);
314
  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
315
 
316
+ const dim3 block_size(32,32);
317
+
318
+ decode_matvec_e8p_kernel<<<grid_size, block_size, 0, stream>>>(
319
+ output.data_ptr<float>(),
320
+ (const uint2*)x.data_ptr<c10::Half>(),
321
+ (const uint2*)weights_compressed.data_ptr<int64_t>(),
322
+ (const uint32_t*)codebook_abs.data_ptr<int32_t>(),
323
+ N,
324
+ K);
325
+
326
+ gpuErrchk(cudaPeekAtLastError());
 
 
 
 
 
327
 
328
  return output;
329
  }
330
+
331
+
332
+
333
+ __global__ static void
334
+ test_tc_kernel(float *__restrict__ output) {
335
+ int laneId = threadIdx.x;
336
+
337
+ uint32_t w0 = (laneId == 0) ? 0x3C003C00 : 0x00000000;
338
+ uint32_t w1 = 0x00000000;
339
+ uint32_t w2 = 0x00000000;
340
+ uint32_t w3 = 0x00000000;
341
+
342
+ uint32_t x0 = (laneId == 0) ? 0x3C003C00 : 0x00000000;
343
+ uint32_t x1 = 0x00000000;
344
+
345
+ float z0 = 0.0;
346
+ float z1 = 0.0;
347
+ float z2 = 0.0;
348
+ float z3 = 0.0;
349
+
350
+ asm(
351
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
352
+ " { %0, %1, %2, %3 },"
353
+ " { %4, %5, %6, %7 },"
354
+ " { %8, %9 },"
355
+ " { %0, %1, %2, %3 };"
356
+ : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3)
357
+ : "r"(w0), "r"(w1), "r"(w2), "r"(w3),
358
+ "r"(x0), "r"(x1)
359
+ );
360
+
361
+ output[laneId*4 + 0] = z0;
362
+ output[laneId*4 + 1] = z1;
363
+ output[laneId*4 + 2] = z2;
364
+ output[laneId*4 + 3] = z3;
365
+ }
366
+
367
+ __host__ extern torch::Tensor test_tc() {
368
+
369
+ torch::TensorOptions options = torch::TensorOptions()
370
+ .dtype(torch::kFloat32)
371
+ .layout(torch::kStrided)
372
+ .device(torch::kCUDA)
373
+ .requires_grad(false);
374
+ torch::Tensor output = torch::zeros(std::vector<int64_t>{32*4}, options);
375
+
376
+ test_tc_kernel<<<1, 32>>>(output.data_ptr<float>());
377
+
378
+ gpuErrchk(cudaPeekAtLastError());
379
+
380
+ return output;
381
+ }
382
+
383
+
384
+
385
+
386
+ __global__ static void
387
+ test_codebook_expand_kernel(uint32_t *__restrict__ output, const uint32_t *__restrict__ codebook_abs) {
388
+ uint32_t a = threadIdx.x;
389
+ uint32_t b = 0;
390
+
391
+ for (int i = 0; i < 8; i++) {
392
+ b |= (((blockIdx.x >> i) & 1) << (4*i));
393
+ }
394
+
395
+ uint32_t s = b;
396
+ s = s ^ (s >> 4);
397
+ s = s ^ (s >> 8);
398
+ s = s ^ (s >> 16);
399
+ uint32_t sb = (s & 15);
400
+ s = b ^ sb;
401
+ sb = sb | (sb << 16);
402
+
403
+ uint32_t x = codebook_abs[(a >> 0) & 255];
404
+ x = x ^ ((s & 0x11111111) * 14);
405
+
406
+ uint32_t o = BASE_OFFSET | ((sb & 0x00010001) << 4);
407
+
408
+ uint32_t w0 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o);
409
+ uint32_t w1 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o);
410
+ uint32_t w2 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o);
411
+ uint32_t w3 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o);
412
+
413
+ output[blockIdx.x*256*4 + threadIdx.x*4 + 0] = w0;
414
+ output[blockIdx.x*256*4 + threadIdx.x*4 + 1] = w1;
415
+ output[blockIdx.x*256*4 + threadIdx.x*4 + 2] = w2;
416
+ output[blockIdx.x*256*4 + threadIdx.x*4 + 3] = w3;
417
+ }
418
+
419
+ __host__ extern torch::Tensor test_codebook_expand(torch::Tensor codebook_abs) {
420
+
421
+ torch::TensorOptions options = torch::TensorOptions()
422
+ .dtype(torch::kFloat16)
423
+ .layout(torch::kStrided)
424
+ .device(torch::kCUDA)
425
+ .requires_grad(false);
426
+ torch::Tensor output = torch::zeros(std::vector<int64_t>{256*256,8}, options);
427
+
428
+ test_codebook_expand_kernel<<<256, 256>>>((uint32_t*)output.data_ptr<c10::Half>(), (const uint32_t*)codebook_abs.data_ptr<int32_t>());
429
+
430
+ gpuErrchk(cudaPeekAtLastError());
431
+
432
+ return output;
433
+ }
434
+
435
+
436
+
437
+
438
+ __global__ static void
439
+ // __launch_bounds__(1024, 1024)
440
+ decompress_packed_e8p_kernel(
441
+ uint32_t *__restrict__ output,
442
+ const uint2 *__restrict__ weights_compressed,
443
+ const uint32_t *__restrict__ codebook_abs,
444
+ int N,
445
+ int K
446
+ ) {
447
+ int warpId = threadIdx.y;
448
+ int laneId = threadIdx.x;
449
+
450
+ for (int iin = blockIdx.x; iin < (N >> 4); iin += gridDim.x) {
451
+
452
+ for (int iik = warpId; iik < (K >> 6); iik += 32) {
453
+ uint2 w_compr = weights_compressed[laneId + 32*iik + (K >> 1)*iin];
454
+ uint32_t a = w_compr.x;
455
+ uint32_t b = w_compr.y;
456
+
457
+ uint32_t s = b;
458
+ s = s ^ (s >> 4);
459
+ s = s ^ (s >> 8);
460
+ s = s ^ (s >> 16);
461
+ uint32_t sb = (s & 15);
462
+ s = b ^ sb;
463
+ sb = sb | (sb << 16);
464
+
465
+ /// BLOCK 01
466
+ {
467
+ uint32_t x = codebook_abs[(a >> 0) & 255];
468
+ x = x ^ ((s & 0x11111111) * 14);
469
+
470
+ uint32_t o = BASE_OFFSET | ((sb & 0x00010001) << 4);
471
+
472
+ uint32_t w00 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o);
473
+ uint32_t w01 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o);
474
+ uint32_t w02 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o);
475
+ uint32_t w03 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o);
476
+
477
+ x = codebook_abs[(a >> 8) & 255];
478
+ x = x ^ ((s & 0x22222222) * 7);
479
+
480
+ o = BASE_OFFSET | ((sb & 0x00020002) << 3);
481
+
482
+ uint32_t w10 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o);
483
+ uint32_t w11 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o);
484
+ uint32_t w12 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o);
485
+ uint32_t w13 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o);
486
+
487
+ output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 0] = w00;
488
+ output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 1] = w01;
489
+ output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 0] = w10;
490
+ output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 1] = w11;
491
+
492
+ output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 2] = w02;
493
+ output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 3] = w03;
494
+ output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 2] = w12;
495
+ output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 3] = w13;
496
+
497
+ }
498
+ /// BLOCK 23
499
+ {
500
+ uint32_t x = codebook_abs[(a >> 16) & 255];
501
+ s = s >> 2;
502
+ x = x ^ ((s & 0x11111111) * 14);
503
+
504
+ uint32_t o = BASE_OFFSET | ((sb & 0x00040004) << 2);
505
+
506
+ uint32_t w00 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o);
507
+ uint32_t w01 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o);
508
+ uint32_t w02 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o);
509
+ uint32_t w03 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o);
510
+
511
+ x = codebook_abs[(a >> 24) & 255];
512
+ x = x ^ ((s & 0x22222222) * 7);
513
+
514
+ o = BASE_OFFSET | ((sb & 0x00080008) << 1);
515
+
516
+ uint32_t w10 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o);
517
+ uint32_t w11 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o);
518
+ uint32_t w12 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o);
519
+ uint32_t w13 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o);
520
+
521
+ output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 0] = w00;
522
+ output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 1] = w01;
523
+ output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 0] = w10;
524
+ output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 1] = w11;
525
+
526
+ output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 2] = w02;
527
+ output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 3] = w03;
528
+ output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 2] = w12;
529
+ output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 3] = w13;
530
+ }
531
+ }
532
+ }
533
+ }
534
+
535
+
536
+ __host__ extern torch::Tensor decompress_packed_e8p(
537
+ torch::Tensor weights_compressed,
538
+ torch::Tensor codebook_abs
539
+ ) {
540
+ CHECK_INPUT(weights_compressed);
541
+ CHECK_INPUT(codebook_abs);
542
+
543
+ TORCH_CHECK(weights_compressed.dim() == 4);
544
+ TORCH_CHECK(weights_compressed.size(3) == 4);
545
+ TORCH_CHECK(weights_compressed.size(2) == 8);
546
+ TORCH_CHECK(codebook_abs.dim() == 1);
547
+ TORCH_CHECK(weights_compressed.scalar_type() == torch::kInt64);
548
+ TORCH_CHECK(codebook_abs.scalar_type() == torch::kInt32);
549
+ TORCH_CHECK(codebook_abs.size(-1) == 256);
550
+
551
+ int64_t N = weights_compressed.size(0) * 16;
552
+ int64_t K = weights_compressed.size(1) << 6;
553
+
554
+ TORCH_CHECK(K % 64 == 0, "K is not divisible by 64");
555
+ TORCH_CHECK(N % 16 == 0, "N is not divisible by 16");
556
+
557
+ TORCH_CHECK(K < 65536, "K is too large");
558
+ TORCH_CHECK(N < 65536, "N is too large");
559
+
560
+ at::DeviceGuard guard(codebook_abs.device());
561
+ torch::TensorOptions options = torch::TensorOptions()
562
+ .dtype(torch::kFloat16)
563
+ .layout(torch::kStrided)
564
+ .device(torch::kCUDA)
565
+ .requires_grad(false);
566
+ torch::Tensor output = torch::zeros(std::vector<int64_t>{N,K}, options);
567
+
568
+ cudaDeviceProp deviceProp;
569
+ cudaGetDeviceProperties(&deviceProp, weights_compressed.get_device());
570
+ int64_t grid_size = static_cast<int64_t>(deviceProp.multiProcessorCount);
571
+ at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
572
+
573
+ const dim3 block_size(32,32);
574
+
575
+ decompress_packed_e8p_kernel<<<grid_size, block_size, 0, stream>>>(
576
+ (uint32_t*)output.data_ptr<c10::Half>(),
577
+ (const uint2*)weights_compressed.data_ptr<int64_t>(),
578
+ (const uint32_t*)codebook_abs.data_ptr<int32_t>(),
579
+ N,
580
+ K);
581
+
582
+ gpuErrchk(cudaPeekAtLastError());
583
+
584
+ return output;
585
+ }
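
The rewrite replaces the generic `decode_matmul_e8p` with a tensor-core mat-vec kernel (`decode_matvec_e8p`), a standalone decompressor (`decompress_packed_e8p`), and two small test kernels. The host-side TORCH_CHECKs fix the calling contract; the sketch below only mirrors those checks, uses zero-filled placeholders for the packed weights and codebook, and assumes the `quiptools_cuda` extension in this directory has been built.

```python
# Shape/dtype contract of the new mat-vec kernel, read off the TORCH_CHECKs above.
# Placeholder (zero) tensors are used; real calls pass the packed E8P weights and
# codebook produced by the quantizer.
import torch
import quiptools_cuda  # built from quip-sharp/quiptools

N, K = 8192, 8192  # both multiples of 16/64 and < 65536 per the checks
x = torch.randn(K, dtype=torch.float16, device='cuda')                     # 1-D activation
w = torch.zeros(N // 16, K // 64, 8, 4, dtype=torch.int64, device='cuda')  # packed weights
cb = torch.zeros(256, dtype=torch.int32, device='cuda')                    # packed codebook

y = quiptools_cuda.decode_matvec_e8p(x, w, cb)   # float32 output of length N
assert y.shape == (N,) and y.dtype == torch.float32
```
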
quip-sharp/quiptools/quiptools_wrapper.cpp CHANGED
@@ -43,13 +43,17 @@ void decompress_e8p_origorder(
43
  torch::Tensor &Y // m x n
44
  );
45
 
46
- torch::Tensor decode_matmul_e8p(
 
 
 
 
 
47
  torch::Tensor x,
48
  torch::Tensor weights_compressed,
49
  torch::Tensor codebook_abs
50
  );
51
 
52
-
53
  void decompress_hi4b1c_packed(
54
  torch::Tensor YIs, // m x (n/8)
55
  torch::Tensor CB, // 16 x 1
@@ -64,7 +68,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
64
  m.def("decompress_d4", &decompress_d4, "decompress_d4");
65
  m.def("decompress_d4_origorder", &decompress_d4_origorder, "decompress_d4_origorder");
66
  m.def("decompress_e8p_origorder", &decompress_e8p_origorder, "decompress_e8p_origorder");
67
- m.def("decode_matmul_e8p", &decode_matmul_e8p, "decode_matmul_e8p");
 
68
  m.def("decompress_hi4b1c_packed", &decompress_hi4b1c_packed, "decompress_hi4b1c_packed");
69
  }
70
 
 
43
  torch::Tensor &Y // m x n
44
  );
45
 
46
+ torch::Tensor decompress_packed_e8p(
47
+ torch::Tensor weights_compressed, // (m/16) x (k/64) x 8 x 4, int64
48
+ torch::Tensor codebook_abs // 256, int32
49
+ );
50
+
51
+ torch::Tensor decode_matvec_e8p(
52
  torch::Tensor x,
53
  torch::Tensor weights_compressed,
54
  torch::Tensor codebook_abs
55
  );
56
 
 
57
  void decompress_hi4b1c_packed(
58
  torch::Tensor YIs, // m x (n/8)
59
  torch::Tensor CB, // 16 x 1
 
68
  m.def("decompress_d4", &decompress_d4, "decompress_d4");
69
  m.def("decompress_d4_origorder", &decompress_d4_origorder, "decompress_d4_origorder");
70
  m.def("decompress_e8p_origorder", &decompress_e8p_origorder, "decompress_e8p_origorder");
71
+ m.def("decompress_packed_e8p", &decompress_packed_e8p, "decompress_packed_e8p");
72
+ m.def("decode_matvec_e8p", &decode_matvec_e8p, "decode_matvec_e8p");
73
  m.def("decompress_hi4b1c_packed", &decompress_hi4b1c_packed, "decompress_hi4b1c_packed");
74
  }
75
 
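
With both new bindings exported, the packed path can be sanity-checked against a dense fp16 matmul. This is a hedged sketch rather than a test from the repo: placeholder tensors again, and the final comparison assumes the two decoders share the same weight layout, which is what `decompress_packed_e8p` exists to provide.

```python
# Cross-check of the two new bindings: decompress the packed weights to a dense
# fp16 matrix and compare its mat-vec against decode_matvec_e8p.
import torch
import quiptools_cuda

N, K = 4096, 4096
x = torch.randn(K, dtype=torch.float16, device='cuda')
w = torch.zeros(N // 16, K // 64, 8, 4, dtype=torch.int64, device='cuda')
cb = torch.zeros(256, dtype=torch.int32, device='cuda')

W = quiptools_cuda.decompress_packed_e8p(w, cb)       # fp16, shape (N, K)
y_fused = quiptools_cuda.decode_matvec_e8p(x, w, cb)  # fp32, shape (N,)
y_dense = W.float() @ x.float()                       # dense reference mat-vec
print(torch.allclose(y_fused, y_dense, atol=1e-1, rtol=1e-2))  # expected True if the decoders agree
```
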
quip-sharp/scripts/upload_hf.py CHANGED
@@ -29,4 +29,5 @@ if __name__ == "__main__":
29
  multi_commits=args.no_multi_commits,
30
  multi_commits_verbose=True,
31
  token=args.write_token,
 
32
  )
 
29
  multi_commits=args.no_multi_commits,
30
  multi_commits_verbose=True,
31
  token=args.write_token,
32
+ create_pr=True, # creates a PR that you must merge manually on the Hub
33
  )
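
Setting `create_pr=True` turns the upload into a pull request against the target repo instead of a direct push, hence the reminder to merge it by hand. A hedged sketch of the corresponding `huggingface_hub` call (repo id, folder, and token are placeholders; the real arguments are assembled in `scripts/upload_hf.py`):

```python
# Hypothetical stand-alone version of the upload call this diff modifies.
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    repo_id='someuser/llama-2-70b-e8p-2bit',   # placeholder repo id
    folder_path='checkpoints/quantized-hf',    # placeholder local checkpoint dir
    token='hf_xxx',                            # write-scoped token
    create_pr=True,                            # opens a PR; merge it manually on the Hub
)
```
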