dbleek commited on
Commit
0d30c2b
1 Parent(s): 1e95f51

added new classifier

Browse files
Files changed (2) hide show
  1. milestone-3.py +25 -22
  2. patent_classifier_v4.pt +3 -0
milestone-3.py CHANGED
@@ -1,8 +1,8 @@
1
  import streamlit as st
2
  import torch
 
3
  from datasets import load_dataset
4
  from transformers import AutoTokenizer
5
- from transformers import AutoModelForSequenceClassification
6
  from transformers import pipeline
7
 
8
  # Load HUPD dataset
@@ -21,21 +21,24 @@ dataset_dict = load_dataset(
21
  filtered_dataset = dataset_dict["validation"].filter(
22
  lambda e: e["decision"] == "ACCEPTED" or e["decision"] == "REJECTED"
23
  )
24
- dataset = filtered_dataset.shuffle(seed=42).select(range(20))
 
 
 
25
  dataset = dataset.sort("patent_number")
26
 
27
  # Create pipeline using model trainned on Colab
28
- model = torch.load("patent_classifier_v2.pt", map_location=torch.device("cpu"))
29
  tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
30
- classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
 
31
 
32
-
33
- def load_patent():
34
  selected_application = dataset.select([applications[st.session_state.id]])
35
  st.session_state.abstract = selected_application["abstract"][0]
36
  st.session_state.claims = selected_application["claims"][0]
37
  st.session_state.title = selected_application["title"][0]
38
-
39
 
40
  st.title("CS-GY-6613 Project Milestone 3")
41
 
@@ -44,11 +47,12 @@ applications = {}
44
  for ds_index, example in enumerate(dataset):
45
  applications.update({example["patent_number"]: ds_index})
46
  st.selectbox(
47
- "Select a patent application:", applications, on_change=load_patent, key="id"
48
  )
49
 
50
- # Application title displayed for additional context only, not used with model
51
- st.text_area("Title", key="title", value=dataset[0]["title"], height=50)
 
52
 
53
  # Classifier input form
54
  with st.form("Input Form"):
@@ -61,16 +65,15 @@ with st.form("Input Form"):
61
  submitted = st.form_submit_button("Get Patentability Score")
62
 
63
  if submitted:
64
- selected_application = dataset.select([applications[st.session_state.id]])
65
- res = classifier(abstract, claims)
66
- if res[0]["label"] == "LABEL_0":
67
- pred = "ACCEPTED"
68
- elif res[0]["label"] == "LABEL_1":
69
- pred = "REJECTED"
70
- score = res[0]["score"]
71
- label = selected_application["decision"][0]
72
- result = st.markdown(
73
- "This text was classified as **{}** with a confidence score of **{}**.".format(
74
- pred, score
75
- )
76
  )
 
 
 
 
1
  import streamlit as st
2
  import torch
3
+ from datasets import combine
4
  from datasets import load_dataset
5
  from transformers import AutoTokenizer
 
6
  from transformers import pipeline
7
 
8
  # Load HUPD dataset
 
21
  filtered_dataset = dataset_dict["validation"].filter(
22
  lambda e: e["decision"] == "ACCEPTED" or e["decision"] == "REJECTED"
23
  )
24
+ seed = 88
25
+ accepted = filtered_dataset.filter(lambda e: e["decision"] == "ACCEPTED").shuffle(seed).select(range(5))
26
+ rejected = filtered_dataset.filter(lambda e: e["decision"] == "REJECTED").shuffle(seed).select(range(5))
27
+ dataset = combine.concatenate_datasets([accepted, rejected])
28
  dataset = dataset.sort("patent_number")
29
 
30
  # Create pipeline using model trainned on Colab
31
+ model = torch.load("patent_classifier_v4.pt", map_location=torch.device("cpu"))
32
  tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
33
+ tokenizer_kwargs = {'padding':True,'truncation':True}
34
+ classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, **tokenizer_kwargs)
35
 
36
+ def load_data():
 
37
  selected_application = dataset.select([applications[st.session_state.id]])
38
  st.session_state.abstract = selected_application["abstract"][0]
39
  st.session_state.claims = selected_application["claims"][0]
40
  st.session_state.title = selected_application["title"][0]
41
+ st.session_state.decision = selected_application["decision"][0]
42
 
43
  st.title("CS-GY-6613 Project Milestone 3")
44
 
 
47
  for ds_index, example in enumerate(dataset):
48
  applications.update({example["patent_number"]: ds_index})
49
  st.selectbox(
50
+ "Select a sample patent application:", applications, on_change=load_data, key="id"
51
  )
52
 
53
+ # Sample title/decision displayed for additional context only, not used with model
54
+ st.text_input("Sample Title", key="title", value=dataset[0]["title"])
55
+ st.text_input("Sample Decision", key="decision", value=dataset[0]["decision"])
56
 
57
  # Classifier input form
58
  with st.form("Input Form"):
 
65
  submitted = st.form_submit_button("Get Patentability Score")
66
 
67
  if submitted:
68
+ tokens = tokenizer(abstract, claims, return_tensors='pt', **tokenizer_kwargs)
69
+ with torch.no_grad():
70
+ output = model(**tokens)
71
+ logits = output.logits
72
+ pred = torch.softmax(logits, dim=1)
73
+ score = pred[0][1] # index 1 of softmax output is probability that decision = ACCEPTED
74
+ st.markdown(
75
+ "This application's patentability score is **{}**".format(score)
 
 
 
 
76
  )
77
+
78
+
79
+
patent_classifier_v4.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae0d471894ba6a7847254acda873e574837547b684b854eaa96efe3b593f8c2d
3
+ size 267882526