File size: 30,630 Bytes
44de051
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This file is a space to construct functions, experiment and see changes directly instead of having to reload the app everytime. It serves as the draft for app.py and contains similar functions except for the streamlit app component"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import streamlit as st\n",
    "from fastai.collab import *\n",
    "import torch\n",
    "from torch import nn\n",
    "import pickle\n",
    "import pandas as pd\n",
    "from transformers import PegasusForConditionalGeneration, PegasusTokenizer\n",
    "import sentencepiece\n",
    "import string\n",
    "import requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[89], line 5\u001b[0m\n\u001b[1;32m      2\u001b[0m dls\u001b[39m=\u001b[39m pd\u001b[39m.\u001b[39mread_pickle(\u001b[39m'\u001b[39m\u001b[39mdataloader.pkl\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[1;32m      4\u001b[0m \u001b[39m# Create an instance of the model\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m learn \u001b[39m=\u001b[39m collab_learner(dls, use_nn\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,layers\u001b[39m=\u001b[39;49m[\u001b[39m20\u001b[39;49m,\u001b[39m10\u001b[39;49m],y_range\u001b[39m=\u001b[39;49m(\u001b[39m0\u001b[39;49m,\u001b[39m10.5\u001b[39;49m))\n\u001b[1;32m      7\u001b[0m \u001b[39m# Load the saved state dictionary\u001b[39;00m\n\u001b[1;32m      8\u001b[0m state_dict \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mload(\u001b[39m'\u001b[39m\u001b[39mmyModel.pth\u001b[39m\u001b[39m'\u001b[39m,map_location\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mdevice(\u001b[39m'\u001b[39m\u001b[39mcpu\u001b[39m\u001b[39m'\u001b[39m))\n",
      "File \u001b[0;32m~/mambaforge/lib/python3.10/site-packages/fastai/collab.py:100\u001b[0m, in \u001b[0;36mcollab_learner\u001b[0;34m(dls, n_factors, use_nn, emb_szs, layers, config, y_range, loss_func, **kwargs)\u001b[0m\n\u001b[1;32m     98\u001b[0m \u001b[39mif\u001b[39;00m y_range \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m: config[\u001b[39m'\u001b[39m\u001b[39my_range\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m y_range\n\u001b[1;32m     99\u001b[0m \u001b[39mif\u001b[39;00m layers \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m: layers \u001b[39m=\u001b[39m [n_factors]\n\u001b[0;32m--> 100\u001b[0m \u001b[39mif\u001b[39;00m use_nn: model \u001b[39m=\u001b[39m EmbeddingNN(emb_szs\u001b[39m=\u001b[39;49memb_szs, layers\u001b[39m=\u001b[39;49mlayers, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mconfig)\n\u001b[1;32m    101\u001b[0m \u001b[39melse\u001b[39;00m:      model \u001b[39m=\u001b[39m EmbeddingDotBias\u001b[39m.\u001b[39mfrom_classes(n_factors, dls\u001b[39m.\u001b[39mclasses, y_range\u001b[39m=\u001b[39my_range)\n\u001b[1;32m    102\u001b[0m \u001b[39mreturn\u001b[39;00m Learner(dls, model, loss_func\u001b[39m=\u001b[39mloss_func, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n",
      "File \u001b[0;32m~/mambaforge/lib/python3.10/site-packages/fastcore/meta.py:40\u001b[0m, in \u001b[0;36mPrePostInitMeta.__call__\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m     38\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mtype\u001b[39m(res)\u001b[39m==\u001b[39m\u001b[39mcls\u001b[39m:\n\u001b[1;32m     39\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(res,\u001b[39m'\u001b[39m\u001b[39m__pre_init__\u001b[39m\u001b[39m'\u001b[39m): res\u001b[39m.\u001b[39m__pre_init__(\u001b[39m*\u001b[39margs,\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m---> 40\u001b[0m     res\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39m*\u001b[39;49margs,\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m     41\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(res,\u001b[39m'\u001b[39m\u001b[39m__post_init__\u001b[39m\u001b[39m'\u001b[39m): res\u001b[39m.\u001b[39m__post_init__(\u001b[39m*\u001b[39margs,\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m     42\u001b[0m \u001b[39mreturn\u001b[39;00m res\n",
      "File \u001b[0;32m~/mambaforge/lib/python3.10/site-packages/fastai/collab.py:89\u001b[0m, in \u001b[0;36mEmbeddingNN.__init__\u001b[0;34m(self, emb_szs, layers, **kwargs)\u001b[0m\n\u001b[1;32m     87\u001b[0m \u001b[39m@delegates\u001b[39m(TabularModel\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m)\n\u001b[1;32m     88\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, emb_szs, layers, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m---> 89\u001b[0m     \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(emb_szs\u001b[39m=\u001b[39;49memb_szs, n_cont\u001b[39m=\u001b[39;49m\u001b[39m0\u001b[39;49m, out_sz\u001b[39m=\u001b[39;49m\u001b[39m1\u001b[39;49m, layers\u001b[39m=\u001b[39;49mlayers, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
      "File \u001b[0;32m~/mambaforge/lib/python3.10/site-packages/fastai/tabular/model.py:53\u001b[0m, in \u001b[0;36mTabularModel.__init__\u001b[0;34m(self, emb_szs, n_cont, out_sz, layers, ps, embed_p, y_range, use_bn, bn_final, bn_cont, act_cls, lin_first)\u001b[0m\n\u001b[1;32m     51\u001b[0m ps \u001b[39m=\u001b[39m ifnone(ps, [\u001b[39m0\u001b[39m]\u001b[39m*\u001b[39m\u001b[39mlen\u001b[39m(layers))\n\u001b[1;32m     52\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m is_listy(ps): ps \u001b[39m=\u001b[39m [ps]\u001b[39m*\u001b[39m\u001b[39mlen\u001b[39m(layers)\n\u001b[0;32m---> 53\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membeds \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mModuleList([Embedding(ni, nf) \u001b[39mfor\u001b[39;00m ni,nf \u001b[39min\u001b[39;00m emb_szs])\n\u001b[1;32m     54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39memb_drop \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mDropout(embed_p)\n\u001b[1;32m     55\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbn_cont \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mBatchNorm1d(n_cont) \u001b[39mif\u001b[39;00m bn_cont \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/mambaforge/lib/python3.10/site-packages/fastai/tabular/model.py:53\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     51\u001b[0m ps \u001b[39m=\u001b[39m ifnone(ps, [\u001b[39m0\u001b[39m]\u001b[39m*\u001b[39m\u001b[39mlen\u001b[39m(layers))\n\u001b[1;32m     52\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m is_listy(ps): ps \u001b[39m=\u001b[39m [ps]\u001b[39m*\u001b[39m\u001b[39mlen\u001b[39m(layers)\n\u001b[0;32m---> 53\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membeds \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mModuleList([Embedding(ni, nf) \u001b[39mfor\u001b[39;00m ni,nf \u001b[39min\u001b[39;00m emb_szs])\n\u001b[1;32m     54\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39memb_drop \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mDropout(embed_p)\n\u001b[1;32m     55\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbn_cont \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mBatchNorm1d(n_cont) \u001b[39mif\u001b[39;00m bn_cont \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/mambaforge/lib/python3.10/site-packages/fastai/layers.py:291\u001b[0m, in \u001b[0;36mEmbedding.__init__\u001b[0;34m(self, ni, nf, std)\u001b[0m\n\u001b[1;32m    290\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, ni, nf, std\u001b[39m=\u001b[39m\u001b[39m0.01\u001b[39m):\n\u001b[0;32m--> 291\u001b[0m     \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(ni, nf)\n\u001b[1;32m    292\u001b[0m     trunc_normal_(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mweight\u001b[39m.\u001b[39mdata, std\u001b[39m=\u001b[39mstd)\n",
      "File \u001b[0;32m~/mambaforge/lib/python3.10/site-packages/torch/nn/modules/sparse.py:142\u001b[0m, in \u001b[0;36mEmbedding.__init__\u001b[0;34m(self, num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight, device, dtype)\u001b[0m\n\u001b[1;32m    140\u001b[0m \u001b[39mif\u001b[39;00m _weight \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    141\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mweight \u001b[39m=\u001b[39m Parameter(torch\u001b[39m.\u001b[39mempty((num_embeddings, embedding_dim), \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mfactory_kwargs))\n\u001b[0;32m--> 142\u001b[0m     \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mreset_parameters()\n\u001b[1;32m    143\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    144\u001b[0m     \u001b[39massert\u001b[39;00m \u001b[39mlist\u001b[39m(_weight\u001b[39m.\u001b[39mshape) \u001b[39m==\u001b[39m [num_embeddings, embedding_dim], \\\n\u001b[1;32m    145\u001b[0m         \u001b[39m'\u001b[39m\u001b[39mShape of weight does not match num_embeddings and embedding_dim\u001b[39m\u001b[39m'\u001b[39m\n",
      "File \u001b[0;32m~/mambaforge/lib/python3.10/site-packages/torch/nn/modules/sparse.py:151\u001b[0m, in \u001b[0;36mEmbedding.reset_parameters\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    150\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mreset_parameters\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 151\u001b[0m     init\u001b[39m.\u001b[39;49mnormal_(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight)\n\u001b[1;32m    152\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fill_padding_idx_with_zero()\n",
      "File \u001b[0;32m~/mambaforge/lib/python3.10/site-packages/torch/nn/init.py:155\u001b[0m, in \u001b[0;36mnormal_\u001b[0;34m(tensor, mean, std)\u001b[0m\n\u001b[1;32m    153\u001b[0m \u001b[39mif\u001b[39;00m torch\u001b[39m.\u001b[39moverrides\u001b[39m.\u001b[39mhas_torch_function_variadic(tensor):\n\u001b[1;32m    154\u001b[0m     \u001b[39mreturn\u001b[39;00m torch\u001b[39m.\u001b[39moverrides\u001b[39m.\u001b[39mhandle_torch_function(normal_, (tensor,), tensor\u001b[39m=\u001b[39mtensor, mean\u001b[39m=\u001b[39mmean, std\u001b[39m=\u001b[39mstd)\n\u001b[0;32m--> 155\u001b[0m \u001b[39mreturn\u001b[39;00m _no_grad_normal_(tensor, mean, std)\n",
      "File \u001b[0;32m~/mambaforge/lib/python3.10/site-packages/torch/nn/init.py:19\u001b[0m, in \u001b[0;36m_no_grad_normal_\u001b[0;34m(tensor, mean, std)\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_no_grad_normal_\u001b[39m(tensor, mean, std):\n\u001b[1;32m     18\u001b[0m     \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n\u001b[0;32m---> 19\u001b[0m         \u001b[39mreturn\u001b[39;00m tensor\u001b[39m.\u001b[39;49mnormal_(mean, std)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Load the data loader \n",
    "dls= pd.read_pickle('dataloader.pkl')\n",
    "\n",
    "# Create an instance of the model\n",
    "learn = collab_learner(dls, use_nn=True,layers=[20,10],y_range=(0,10.5))\n",
    "\n",
    "# Load the saved state dictionary\n",
    "state_dict = torch.load('myModel.pth',map_location=torch.device('cpu'))\n",
    "\n",
    "# Assign the loaded state dictionary to the model's load_state_dict() method\n",
    "learn.model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_3_recs(book):\n",
    "  book_factors = learn.model.embeds[1].weight\n",
    "  idx = dls.classes['title'].o2i[book]\n",
    "  distances = nn.CosineSimilarity(dim=1)(book_factors, book_factors[idx][None])\n",
    "  idxs = distances.argsort(descending=True)[1:4]\n",
    "  recs = [dls.classes['title'][i] for i in idxs]\n",
    "  return recs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#load books dataframe\n",
    "books_df = pd.read_csv('./data/BX_Books.csv', sep=';',encoding='latin-1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['http://images.amazon.com/images/P/0451524934.01.LZZZZZZZ.jpg',\n",
       " 'http://images.amazon.com/images/P/185326041X.01.LZZZZZZZ.jpg']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#get covers\n",
    "def get_covers(recs):\n",
    "  imgs = [books_df[books_df['Book-Title']==r]['Image-URL-L'].tolist()[0]for r in recs]\n",
    "  return imgs\n",
    "\n",
    "get_covers(['1984', 'The Great Gatsby'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-24 16:15:04.552 \n",
      "  \u001b[33m\u001b[1mWarning:\u001b[0m to view this Streamlit app on a browser, run it with the following\n",
      "  command:\n",
      "\n",
      "    streamlit run /Users/irenenguyen/mambaforge/lib/python3.10/site-packages/ipykernel_launcher.py [ARGUMENTS]\n"
     ]
    }
   ],
   "source": [
    "user_input = st.text_input(\"What's your favorite book?\")\n",
    "recs = get_3_recs(user_input)\n",
    "st.write(\"Try these books:\", \",\".join(recs))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Description Summarizer"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Getting book description from Google Books API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def search_book_description(title):\n",
    "    # Google Books API endpoint for book search\n",
    "    url = \"https://www.googleapis.com/books/v1/volumes\"\n",
    "\n",
    "    # Parameters for the book search\n",
    "    params = {\n",
    "        \"q\": title,\n",
    "        \"maxResults\": 1\n",
    "    }\n",
    "\n",
    "    # Send GET request to Google Books API\n",
    "    response = requests.get(url, params=params)\n",
    "\n",
    "    # Check if the request was successful\n",
    "    if response.status_code == 200:\n",
    "        # Parse the JSON response to extract the book description\n",
    "        data = response.json()\n",
    "\n",
    "        if \"items\" in data and len(data[\"items\"]) > 0:\n",
    "            book_description = data[\"items\"][0][\"volumeInfo\"].get(\"description\", \"No description available.\")\n",
    "            return book_description\n",
    "        else:\n",
    "            print(\"No book found with the given title.\")\n",
    "            return None\n",
    "    else:\n",
    "        # If the request failed, print the error message\n",
    "        print(\"Error:\", response.status_code, response.text)\n",
    "        return None"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Summarization Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using a model of type pegasus_x to instantiate a model of type pegasus. This is not supported for all configurations of models and can yield errors.\n",
      "Some weights of the model checkpoint at pszemraj/pegasus-x-large-book-summary were not used when initializing PegasusForConditionalGeneration: ['model.encoder.layers.5.global_self_attn_layer_norm.bias', 'model.encoder.layers.2.global_self_attn_layer_norm.weight', 'model.encoder.layers.15.global_self_attn_layer_norm.weight', 'model.encoder.layers.11.global_self_attn_layer_norm.weight', 'model.encoder.layers.12.global_self_attn_layer_norm.bias', 'model.encoder.layers.4.global_self_attn_layer_norm.weight', 'model.encoder.layers.0.global_self_attn_layer_norm.bias', 'model.encoder.layers.10.global_self_attn_layer_norm.bias', 'model.encoder.layers.5.global_self_attn_layer_norm.weight', 'model.encoder.layers.7.global_self_attn_layer_norm.bias', 'model.encoder.layers.11.global_self_attn_layer_norm.bias', 'model.encoder.layers.13.global_self_attn_layer_norm.bias', 'model.encoder.layers.13.global_self_attn_layer_norm.weight', 'model.encoder.layers.14.global_self_attn_layer_norm.weight', 'model.encoder.layers.9.global_self_attn_layer_norm.weight', 'model.encoder.layers.8.global_self_attn_layer_norm.weight', 'model.encoder.layers.3.global_self_attn_layer_norm.weight', 'model.encoder.layers.4.global_self_attn_layer_norm.bias', 'model.encoder.layers.14.global_self_attn_layer_norm.bias', 'model.encoder.layers.10.global_self_attn_layer_norm.weight', 'model.encoder.layers.6.global_self_attn_layer_norm.bias', 'model.encoder.layers.2.global_self_attn_layer_norm.bias', 'model.encoder.layers.1.global_self_attn_layer_norm.weight', 'model.encoder.layers.9.global_self_attn_layer_norm.bias', 'model.encoder.layers.0.global_self_attn_layer_norm.weight', 'model.encoder.layers.12.global_self_attn_layer_norm.weight', 'model.encoder.layers.7.global_self_attn_layer_norm.weight', 'model.encoder.layers.3.global_self_attn_layer_norm.bias', 'model.encoder.layers.6.global_self_attn_layer_norm.weight', 'model.encoder.layers.15.global_self_attn_layer_norm.bias', 'model.encoder.embed_global.weight', 'model.encoder.layers.1.global_self_attn_layer_norm.bias', 'model.encoder.layers.8.global_self_attn_layer_norm.bias']\n",
      "- This IS expected if you are initializing PegasusForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing PegasusForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at pszemraj/pegasus-x-large-book-summary and are newly initialized: ['model.decoder.layers.7.encoder_attn.q_proj.bias', 'model.encoder.layers.6.self_attn.q_proj.bias', 'model.decoder.layers.4.self_attn.out_proj.bias', 'model.decoder.layers.6.encoder_attn.q_proj.bias', 'model.decoder.layers.12.self_attn.v_proj.bias', 'model.decoder.layers.6.encoder_attn.k_proj.bias', 'model.decoder.layers.7.encoder_attn.out_proj.bias', 'model.decoder.layers.3.encoder_attn.v_proj.bias', 'model.decoder.layers.7.encoder_attn.v_proj.bias', 'model.decoder.layers.13.encoder_attn.out_proj.bias', 'model.decoder.layers.10.self_attn.k_proj.bias', 'model.decoder.layers.2.self_attn.out_proj.bias', 'model.encoder.layers.12.self_attn.k_proj.bias', 'model.decoder.layers.9.encoder_attn.k_proj.bias', 'model.decoder.layers.11.encoder_attn.v_proj.bias', 'model.decoder.layers.7.encoder_attn.k_proj.bias', 'model.encoder.layers.7.self_attn.v_proj.bias', 'model.encoder.layers.12.self_attn.q_proj.bias', 'model.encoder.layers.0.self_attn.v_proj.bias', 'model.decoder.layers.2.self_attn.v_proj.bias', 'model.decoder.layers.11.self_attn.q_proj.bias', 'model.decoder.layers.5.self_attn.v_proj.bias', 'model.decoder.layers.15.self_attn.v_proj.bias', 'model.decoder.layers.0.encoder_attn.v_proj.bias', 'model.decoder.layers.12.self_attn.out_proj.bias', 'model.encoder.layers.0.self_attn.out_proj.bias', 'model.encoder.layers.10.self_attn.out_proj.bias', 'model.encoder.layers.13.self_attn.k_proj.bias', 'model.encoder.layers.10.self_attn.v_proj.bias', 'model.decoder.layers.6.self_attn.out_proj.bias', 'model.decoder.layers.14.encoder_attn.out_proj.bias', 'model.decoder.layers.3.encoder_attn.out_proj.bias', 'model.encoder.layers.14.self_attn.q_proj.bias', 'model.decoder.layers.2.encoder_attn.out_proj.bias', 'model.decoder.layers.6.encoder_attn.v_proj.bias', 'model.decoder.layers.10.self_attn.v_proj.bias', 'model.encoder.layers.9.self_attn.q_proj.bias', 'model.decoder.layers.4.encoder_attn.q_proj.bias', 'model.decoder.layers.5.encoder_attn.k_proj.bias', 'model.encoder.layers.14.self_attn.k_proj.bias', 'model.decoder.layers.7.self_attn.out_proj.bias', 'model.decoder.layers.0.encoder_attn.out_proj.bias', 'model.decoder.layers.4.encoder_attn.v_proj.bias', 'model.encoder.layers.13.self_attn.out_proj.bias', 'model.decoder.layers.5.self_attn.out_proj.bias', 'model.decoder.layers.2.self_attn.k_proj.bias', 'model.decoder.layers.0.self_attn.k_proj.bias', 'model.decoder.layers.1.self_attn.v_proj.bias', 'model.decoder.layers.4.self_attn.k_proj.bias', 'model.encoder.layers.3.self_attn.v_proj.bias', 'model.decoder.layers.10.encoder_attn.v_proj.bias', 'model.encoder.layers.10.self_attn.k_proj.bias', 'model.decoder.layers.14.encoder_attn.q_proj.bias', 'model.encoder.layers.11.self_attn.k_proj.bias', 'model.decoder.layers.7.self_attn.q_proj.bias', 'model.encoder.layers.9.self_attn.v_proj.bias', 'model.decoder.layers.5.self_attn.k_proj.bias', 'model.decoder.layers.8.encoder_attn.out_proj.bias', 'model.decoder.layers.1.encoder_attn.out_proj.bias', 'model.decoder.layers.14.encoder_attn.k_proj.bias', 'model.decoder.layers.4.encoder_attn.k_proj.bias', 'model.decoder.layers.6.self_attn.q_proj.bias', 'model.decoder.layers.3.encoder_attn.q_proj.bias', 'model.decoder.layers.8.self_attn.out_proj.bias', 'model.encoder.layers.8.self_attn.q_proj.bias', 'model.decoder.layers.4.self_attn.q_proj.bias', 'model.decoder.layers.11.self_attn.k_proj.bias', 'model.decoder.layers.9.self_attn.out_proj.bias', 'model.decoder.layers.10.encoder_attn.q_proj.bias', 'model.encoder.layers.3.self_attn.out_proj.bias', 'model.encoder.layers.12.self_attn.v_proj.bias', 'model.encoder.layers.6.self_attn.out_proj.bias', 'model.decoder.layers.2.encoder_attn.v_proj.bias', 'model.decoder.layers.3.self_attn.out_proj.bias', 'model.encoder.layers.11.self_attn.q_proj.bias', 'model.encoder.layers.13.self_attn.q_proj.bias', 'model.decoder.layers.14.self_attn.out_proj.bias', 'model.encoder.layers.13.self_attn.v_proj.bias', 'model.decoder.layers.0.self_attn.v_proj.bias', 'model.decoder.layers.2.self_attn.q_proj.bias', 'model.decoder.layers.13.self_attn.k_proj.bias', 'model.encoder.layers.8.self_attn.k_proj.bias', 'model.decoder.layers.14.self_attn.k_proj.bias', 'model.decoder.layers.3.encoder_attn.k_proj.bias', 'model.encoder.layers.8.self_attn.v_proj.bias', 'model.decoder.layers.10.encoder_attn.out_proj.bias', 'model.decoder.layers.0.encoder_attn.q_proj.bias', 'model.decoder.layers.11.encoder_attn.k_proj.bias', 'model.decoder.layers.8.self_attn.q_proj.bias', 'model.decoder.layers.15.encoder_attn.out_proj.bias', 'model.encoder.layers.7.self_attn.q_proj.bias', 'model.encoder.layers.12.self_attn.out_proj.bias', 'model.decoder.layers.12.self_attn.q_proj.bias', 'model.encoder.layers.15.self_attn.out_proj.bias', 'model.encoder.layers.15.self_attn.v_proj.bias', 'model.encoder.layers.5.self_attn.q_proj.bias', 'model.decoder.layers.13.encoder_attn.q_proj.bias', 'model.decoder.layers.7.self_attn.k_proj.bias', 'model.decoder.layers.12.self_attn.k_proj.bias', 'model.decoder.layers.10.self_attn.q_proj.bias', 'model.encoder.layers.3.self_attn.k_proj.bias', 'model.decoder.layers.13.self_attn.v_proj.bias', 'model.decoder.layers.0.encoder_attn.k_proj.bias', 'model.decoder.layers.15.encoder_attn.k_proj.bias', 'model.decoder.layers.6.encoder_attn.out_proj.bias', 'model.decoder.layers.4.encoder_attn.out_proj.bias', 'model.decoder.layers.15.self_attn.q_proj.bias', 'model.encoder.layers.15.self_attn.k_proj.bias', 'model.encoder.layers.2.self_attn.k_proj.bias', 'model.encoder.layers.14.self_attn.v_proj.bias', 'model.decoder.layers.8.encoder_attn.k_proj.bias', 'model.encoder.layers.4.self_attn.q_proj.bias', 'model.encoder.layers.8.self_attn.out_proj.bias', 'model.encoder.layers.1.self_attn.q_proj.bias', 'model.decoder.layers.12.encoder_attn.v_proj.bias', 'model.decoder.layers.2.encoder_attn.k_proj.bias', 'model.encoder.layers.1.self_attn.v_proj.bias', 'model.encoder.layers.4.self_attn.k_proj.bias', 'model.encoder.layers.4.self_attn.out_proj.bias', 'model.decoder.layers.3.self_attn.v_proj.bias', 'model.decoder.layers.14.self_attn.v_proj.bias', 'model.decoder.layers.3.self_attn.q_proj.bias', 'model.decoder.layers.13.encoder_attn.k_proj.bias', 'model.decoder.layers.15.encoder_attn.v_proj.bias', 'model.decoder.layers.0.self_attn.q_proj.bias', 'model.encoder.layers.1.self_attn.k_proj.bias', 'model.decoder.layers.13.self_attn.q_proj.bias', 'model.decoder.layers.5.encoder_attn.out_proj.bias', 'model.decoder.layers.12.encoder_attn.q_proj.bias', 'model.encoder.layers.7.self_attn.k_proj.bias', 'model.encoder.layers.0.self_attn.k_proj.bias', 'model.decoder.layers.7.self_attn.v_proj.bias', 'model.decoder.layers.13.encoder_attn.v_proj.bias', 'model.decoder.layers.12.encoder_attn.k_proj.bias', 'model.decoder.layers.5.self_attn.q_proj.bias', 'model.decoder.layers.11.encoder_attn.q_proj.bias', 'model.decoder.layers.8.self_attn.k_proj.bias', 'model.encoder.layers.1.self_attn.out_proj.bias', 'model.encoder.layers.5.self_attn.out_proj.bias', 'model.decoder.layers.1.self_attn.q_proj.bias', 'model.decoder.layers.1.self_attn.k_proj.bias', 'model.encoder.layers.15.self_attn.q_proj.bias', 'model.encoder.layers.9.self_attn.k_proj.bias', 'model.decoder.layers.9.self_attn.k_proj.bias', 'model.encoder.layers.6.self_attn.k_proj.bias', 'model.decoder.layers.1.encoder_attn.v_proj.bias', 'model.decoder.layers.9.encoder_attn.q_proj.bias', 'model.encoder.layers.11.self_attn.out_proj.bias', 'model.decoder.layers.6.self_attn.v_proj.bias', 'model.decoder.layers.9.encoder_attn.v_proj.bias', 'model.decoder.layers.14.encoder_attn.v_proj.bias', 'model.decoder.layers.11.self_attn.v_proj.bias', 'model.decoder.layers.1.self_attn.out_proj.bias', 'model.decoder.layers.11.encoder_attn.out_proj.bias', 'model.decoder.layers.5.encoder_attn.v_proj.bias', 'model.decoder.layers.15.encoder_attn.q_proj.bias', 'model.encoder.layers.7.self_attn.out_proj.bias', 'model.decoder.layers.2.encoder_attn.q_proj.bias', 'model.encoder.layers.4.self_attn.v_proj.bias', 'model.encoder.layers.2.self_attn.q_proj.bias', 'model.decoder.layers.9.self_attn.v_proj.bias', 'model.decoder.layers.8.encoder_attn.q_proj.bias', 'model.decoder.layers.8.encoder_attn.v_proj.bias', 'model.encoder.layers.10.self_attn.q_proj.bias', 'model.decoder.layers.5.encoder_attn.q_proj.bias', 'model.decoder.layers.13.self_attn.out_proj.bias', 'model.decoder.layers.1.encoder_attn.k_proj.bias', 'model.encoder.layers.11.self_attn.v_proj.bias', 'model.decoder.layers.15.self_attn.out_proj.bias', 'model.decoder.layers.0.self_attn.out_proj.bias', 'model.encoder.layers.3.self_attn.q_proj.bias', 'model.encoder.layers.2.self_attn.v_proj.bias', 'model.decoder.layers.15.self_attn.k_proj.bias', 'model.decoder.layers.4.self_attn.v_proj.bias', 'model.encoder.layers.9.self_attn.out_proj.bias', 'model.encoder.layers.5.self_attn.v_proj.bias', 'model.encoder.layers.14.self_attn.out_proj.bias', 'model.decoder.layers.1.encoder_attn.q_proj.bias', 'model.decoder.layers.10.encoder_attn.k_proj.bias', 'model.decoder.layers.10.self_attn.out_proj.bias', 'model.decoder.layers.14.self_attn.q_proj.bias', 'model.decoder.layers.12.encoder_attn.out_proj.bias', 'model.decoder.layers.6.self_attn.k_proj.bias', 'model.decoder.layers.11.self_attn.out_proj.bias', 'model.decoder.layers.9.encoder_attn.out_proj.bias', 'model.decoder.layers.3.self_attn.k_proj.bias', 'model.encoder.layers.6.self_attn.v_proj.bias', 'model.encoder.layers.5.self_attn.k_proj.bias', 'model.decoder.layers.8.self_attn.v_proj.bias', 'model.encoder.layers.0.self_attn.q_proj.bias', 'model.encoder.layers.2.self_attn.out_proj.bias', 'model.decoder.layers.9.self_attn.q_proj.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "#load tokenizer\n",
    "tokenizer = PegasusTokenizer.from_pretrained(\"pszemraj/pegasus-x-large-book-summary\")\n",
    "#load model\n",
    "model = PegasusForConditionalGeneration.from_pretrained(\"pszemraj/pegasus-x-large-book-summary\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "#function to ensure summaries end with punctuation\n",
    "def cut(sum):\n",
    "    last_punc_idx = max(sum.rfind(p) for p in string.punctuation)\n",
    "    output = sum[:last_punc_idx + 1]\n",
    "    return output\n",
    "\n",
    "\n",
    "#function to summarize\n",
    "\n",
    "def summarize(des_list):\n",
    "    if \"No description available.\" in des_list:\n",
    "       idx = des_list.index(\"No description available.\")\n",
    "       des = des_list.copy()\n",
    "       des.pop(idx)\n",
    "       rest = summarize(des)\n",
    "       rest.insert(idx,'No description available.')\n",
    "       return rest\n",
    "    else: \n",
    "        # Tokenize all the descriptions in the list\n",
    "        encoded_inputs = tokenizer(des_list, truncation=True, padding=\"longest\", return_tensors=\"pt\")\n",
    "\n",
    "        # Generate summaries for all the inputs\n",
    "        summaries = model.generate(**encoded_inputs, max_new_tokens=100)\n",
    "\n",
    "        # Decode the summaries and process them\n",
    "        outputs = tokenizer.batch_decode(summaries, skip_special_tokens=True)\n",
    "        outputs = list(map(cut, outputs))\n",
    "        return outputs\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}