Spaces:
Runtime error
Runtime error
added the local bayes library, removed bayes from req.txt
Browse files- bayes/__init__.py +0 -0
- bayes/__pycache__/__init__.cpython-39.pyc +0 -0
- bayes/__pycache__/data_routines.cpython-39.pyc +0 -0
- bayes/__pycache__/explanations.cpython-39.pyc +0 -0
- bayes/__pycache__/models.cpython-39.pyc +0 -0
- bayes/__pycache__/regression.cpython-39.pyc +0 -0
- bayes/data_routines.py +218 -0
- bayes/explanations.py +701 -0
- bayes/models.py +163 -0
- bayes/regression.py +148 -0
- requirements.txt +0 -1
bayes/__init__.py
ADDED
File without changes
|
bayes/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (189 Bytes). View file
|
|
bayes/__pycache__/data_routines.cpython-39.pyc
ADDED
Binary file (6.18 kB). View file
|
|
bayes/__pycache__/explanations.cpython-39.pyc
ADDED
Binary file (17.9 kB). View file
|
|
bayes/__pycache__/models.cpython-39.pyc
ADDED
Binary file (5.28 kB). View file
|
|
bayes/__pycache__/regression.cpython-39.pyc
ADDED
Binary file (4.26 kB). View file
|
|
bayes/data_routines.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Routines for processing data."""
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
import pandas as pd
|
5 |
+
from PIL import Image
|
6 |
+
from skimage.segmentation import slic, mark_boundaries
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torchvision import datasets, transforms
|
10 |
+
|
11 |
+
# The number of segments to use for the images
|
12 |
+
NSEGMENTS = 20
|
13 |
+
PARAMS = {
|
14 |
+
'protected_class': 1,
|
15 |
+
'unprotected_class': 0,
|
16 |
+
'positive_outcome': 1,
|
17 |
+
'negative_outcome': 0
|
18 |
+
}
|
19 |
+
IMAGENET_LABELS = {
|
20 |
+
'french_bulldog': 245,
|
21 |
+
'scuba_diver': 983,
|
22 |
+
'corn': 987,
|
23 |
+
'broccoli': 927
|
24 |
+
}
|
25 |
+
|
26 |
+
def get_and_preprocess_compas_data():
    """Handle processing of COMPAS according to: https://github.com/propublica/compas-analysis

    Filters the raw ProPublica two-year scores to the standard analysis subset,
    derives a length-of-stay feature, binarizes the 'High' risk score as the
    negative outcome, and one-hot encodes categoricals with race kept as a
    binary protected-attribute column.

    Returns
    ----------
    dict with keys:
        "X": np.ndarray of processed features
        "y": np.ndarray of binary outcomes
        "column_names": list of feature names
        "cat_indices": indices of the categorical feature columns
    """
    PROTECTED_CLASS = PARAMS['protected_class']
    UNPROTECTED_CLASS = PARAMS['unprotected_class']
    POSITIVE_OUTCOME = PARAMS['positive_outcome']
    NEGATIVE_OUTCOME = PARAMS['negative_outcome']

    # Standard ProPublica filtering: valid screening window, known recidivism,
    # ordinary charge degrees, and non-missing score text.
    compas_df = pd.read_csv("../data/compas-scores-two-years.csv", index_col=0)
    compas_df = compas_df.loc[(compas_df['days_b_screening_arrest'] <= 30) &
                              (compas_df['days_b_screening_arrest'] >= -30) &
                              (compas_df['is_recid'] != -1) &
                              (compas_df['c_charge_degree'] != "O") &
                              (compas_df['score_text'] != "NA")]

    # Jail time in days, derived from the in/out timestamps.
    compas_df['length_of_stay'] = (pd.to_datetime(compas_df['c_jail_out']) - pd.to_datetime(compas_df['c_jail_in'])).dt.days
    X = compas_df[['age', 'two_year_recid','c_charge_degree', 'race', 'sex', 'priors_count', 'length_of_stay']]

    # if person has high score give them the _negative_ model outcome
    y = np.array([NEGATIVE_OUTCOME if score == 'High' else POSITIVE_OUTCOME for score in compas_df['score_text']])
    sens = X.pop('race')

    # assign African-American as the protected class
    X = pd.get_dummies(X)
    sensitive_attr = np.array(pd.get_dummies(sens).pop('African-American'))
    X['race'] = sensitive_attr

    # make sure everything is lining up
    assert all((sens == 'African-American') == (X['race'] == PROTECTED_CLASS))
    cols = [col for col in X]

    # Hard-coded indices of the categorical/one-hot columns after get_dummies
    # — assumes a fixed post-dummies column order; TODO confirm against data.
    categorical_features = [1, 4, 5, 6, 7, 8]

    output = {
        "X": X.values,
        "y": y,
        "column_names": cols,
        "cat_indices": categorical_features
    }

    return output
74 |
+
|
75 |
+
def get_and_preprocess_german():
    """Handle processing of German. We use a preprocessed version of German from Ustun et. al.
    https://arxiv.org/abs/1809.06514. Thanks Berk!

    Binarizes gender (Male -> 1), maps GoodCustomer == 1 to the positive
    outcome, and drops the loan-purpose column.

    Returns
    ----------
    dict with keys:
        "X": np.ndarray of processed features
        "y": np.ndarray of binary outcomes
        "column_names": list of feature names
        "cat_indices": indices of the categorical feature columns
    """
    PROTECTED_CLASS = PARAMS['protected_class']
    UNPROTECTED_CLASS = PARAMS['unprotected_class']
    POSITIVE_OUTCOME = PARAMS['positive_outcome']
    NEGATIVE_OUTCOME = PARAMS['negative_outcome']

    X = pd.read_csv("../data/german_processed.csv")
    y = X["GoodCustomer"]

    X = X.drop(["GoodCustomer", "PurposeOfLoan"], axis=1)
    # Binarize gender: Male -> 1, everything else -> 0
    X['Gender'] = [1 if v == "Male" else 0 for v in X['Gender'].values]

    y = np.array([POSITIVE_OUTCOME if p == 1 else NEGATIVE_OUTCOME for p in y.values])
    # First three columns plus everything from column 9 on are categorical
    # in the Ustun et al. preprocessing — TODO confirm against the CSV schema.
    categorical_features = [0, 1, 2] + list(range(9, X.shape[1]))

    output = {
        "X": X.values,
        "y": y,
        "column_names": [c for c in X],
        "cat_indices": categorical_features,
    }

    return output
107 |
+
|
108 |
+
def get_PIL_transf():
    """Build the PIL-level preprocessing pipeline.

    Resizes each image to 256x256, then center-crops it to 224x224 —
    the standard input preparation used before segmentation.
    """
    return transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
    ])
115 |
+
|
116 |
+
def load_image(path):
    """Open the image file at *path* and return it converted to RGB."""
    absolute = os.path.abspath(path)
    with open(absolute, 'rb') as handle, Image.open(handle) as image:
        return image.convert('RGB')
121 |
+
|
122 |
+
def get_imagenet(name, get_label=True):
    """Gets the imagenet data.

    Walks ../data/<name>, loads every image (skipping macOS .DS_Store files),
    applies the resize/center-crop transform, and computes SLIC superpixel
    segments for each image.

    Arguments:
        name: The name of the imagenet dataset (a directory under ../data)
        get_label: If True, attach the known imagenet class index for `name`
            (must be a key of IMAGENET_LABELS); otherwise labels are -1.

    Returns:
        dict with "X" (image array), "y" (labels), and "segments" (SLIC maps).
    """
    images_paths = []

    # Store all the paths of the images
    data_dir = os.path.join("../data", name)
    for (dirpath, dirnames, filenames) in os.walk(data_dir):
        for fn in filenames:
            if fn != ".DS_Store":  # skip macOS metadata files
                images_paths.append(os.path.join(dirpath, fn))

    # Load & do transforms for the images
    pill_transf = get_PIL_transf()
    images, segs = [], []
    for img_path in images_paths:
        img = load_image(img_path)
        PIL_transformed_image = np.array(pill_transf(img))
        # Superpixel segmentation — segments later serve as the explanation units
        segments = slic(PIL_transformed_image, n_segments=NSEGMENTS, compactness=100, sigma=1)

        images.append(PIL_transformed_image)
        segs.append(segments)

    images = np.array(images)

    if get_label:
        assert name in IMAGENET_LABELS, "Get label set to True but name not in known imagenet labels"
        y = np.ones(images.shape[0]) * IMAGENET_LABELS[name]
    else:
        # -1 marks "label unknown / not requested"
        y = np.ones(images.shape[0]) * -1

    segs = np.array(segs)

    output = {
        "X": images,
        "y": y,
        "segments": segs
    }

    return output
165 |
+
|
166 |
+
|
167 |
+
def get_mnist(num):
    """Gets the MNIST data for a certain digit.

    Downloads (if needed) the MNIST test split, keeps only images whose label
    equals `num`, and computes a SLIC segmentation for each kept image.

    Arguments:
        num: The mnist digit to get (0-9)

    Returns:
        dict with "X" (normalized 1x28x28 images), "y" (all equal to num),
        and "segments" (1x28x28 SLIC maps).
    """

    # Get the mnist data (test split, normalized with standard MNIST stats)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data/mnist',
                                              train=False,
                                              download=True,
                                              transform=transforms.Compose([transforms.ToTensor(),
                                                                            transforms.Normalize((0.1307,), (0.3081,))
                                                                            ])),
                                              batch_size=1,
                                              shuffle=False)

    all_test_mnist_of_label_num, all_test_segments_of_label_num = [], []

    # Get all instances of label num
    for data, y in test_loader:
        if y[0] == num:
            # Apply segmentation; drop the batch axis first
            sample = np.squeeze(data.numpy().astype('double'),axis=0)
            # SLIC expects HxWxC, the model pipeline expects CxHxW — hence the reshapes
            segments = slic(sample.reshape(28,28,1), n_segments=NSEGMENTS, compactness=1, sigma=0.1).reshape(1,28,28)
            all_test_mnist_of_label_num.append(sample)
            all_test_segments_of_label_num.append(segments)

    all_test_mnist_of_label_num = np.array(all_test_mnist_of_label_num)
    all_test_segments_of_label_num = np.array(all_test_segments_of_label_num)

    output = {
        "X": all_test_mnist_of_label_num,
        "y": np.ones(all_test_mnist_of_label_num.shape[0]) * num,
        "segments": all_test_segments_of_label_num
    }

    return output
205 |
+
|
206 |
+
def get_dataset_by_name(name, get_label=True):
    """Dispatch to the appropriate dataset loader by name.

    Arguments:
        name: One of "compas", "german", "mnist<digit>" (e.g. "mnist4"),
            or "imagenet-<class>" (the class is taken from name[9:]).
        get_label: Passed to the imagenet loader; whether to attach the true
            imagenet label.

    Returns:
        The loader's output dict, with a "name" key added.

    Raises:
        NameError: If `name` matches no known dataset.
    """
    if name == "compas":
        d = get_and_preprocess_compas_data()
    elif name == "german":
        d = get_and_preprocess_german()
    elif "mnist" in name:
        # The digit of interest is the final character, e.g. "mnist4" -> 4
        d = get_mnist(int(name[-1]))
    elif "imagenet" in name:
        d = get_imagenet(name[9:], get_label=get_label)
    else:
        # Bug fix: previously raised NameError("Unkown dataset %s", name) —
        # the %s placeholder was never formatted (the tuple became the
        # exception args) and "Unknown" was misspelled.
        raise NameError(f"Unknown dataset {name}")
    d['name'] = name
    return d
bayes/explanations.py
ADDED
@@ -0,0 +1,701 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Bayesian Local Explanations.
|
2 |
+
|
3 |
+
This code implements bayesian local explanations. The code supports the LIME & SHAP
|
4 |
+
kernels. Along with the LIME & SHAP feature importances, bayesian local explanations
|
5 |
+
also support uncertainty expression over the feature importances.
|
6 |
+
"""
|
7 |
+
import logging
|
8 |
+
|
9 |
+
from copy import deepcopy
|
10 |
+
from functools import reduce
|
11 |
+
from multiprocessing import Pool
|
12 |
+
import numpy as np
|
13 |
+
import operator as op
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
import sklearn
|
17 |
+
import sklearn.preprocessing
|
18 |
+
from sklearn.linear_model import Ridge, Lasso
|
19 |
+
from lime import lime_image, lime_tabular
|
20 |
+
|
21 |
+
from bayes.regression import BayesianLinearRegression
|
22 |
+
|
23 |
+
LDATA, LINVERSE, LSCALED, LDISTANCES, LY = list(range(5))
|
24 |
+
SDATA, SINVERSE, SY = list(range(3))
|
25 |
+
|
26 |
+
class BayesLocalExplanations:
|
27 |
+
"""Bayesian Local Explanations.
|
28 |
+
|
29 |
+
This class implements the bayesian local explanations.
|
30 |
+
"""
|
31 |
+
def __init__(self,
             training_data,
             data="image",
             kernel="lime",
             credible_interval=95,
             mode="classification",
             categorical_features=None,
             discretize_continuous=True,
             save_logs=False,
             log_file_name="bayes.log",
             width=0.75,
             verbose=False):
    """Initialize the local explanations.

    Arguments:
        training_data: 2D array of training data used to fit the perturbation
            sampler (its number of columns also sets the kernel width).
        data: The type of data, either "image" or "tabular"
        kernel: The kernel to use, either "lime" or "shap"
        credible_interval: The % credible interval to use for the feature importance
            uncertainty.
        mode: Whether to run with classification or regression.
        categorical_features: The indices of the categorical features. Defaults
            to no categorical features (None is treated as an empty list).
        discretize_continuous: Whether to discretize continuous features.
        save_logs: Whether to save logs from the run.
        log_file_name: The name of log file.
        width: Base kernel width, scaled by sqrt(n_features) (LIME convention).
        verbose: Whether to show progress bars during sampling.
    """

    assert kernel in ["lime", "shap"], f"Kernel must be one of lime or shap, not {kernel}"
    assert data in ["image", "tabular"], f"Data must be one of image or tabular, not {data}"
    assert mode in ["classification"], "Others modes like regression are not implemented"

    if save_logs:
        logging.basicConfig(filename=log_file_name,
                            filemode='a',
                            level=logging.INFO)

    logging.info("==============================================")
    logging.info("Initializing Bayes%s %s explanations", kernel, data)
    logging.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

    self.cred_int = credible_interval
    self.data = data
    self.kernel = kernel
    self.mode = mode
    # Avoid the shared-mutable-default pitfall: None stands in for [].
    self.categorical_features = [] if categorical_features is None else categorical_features
    self.discretize_continuous = discretize_continuous
    self.verbose = verbose
    # Kernel width scales with sqrt of the feature dimension (LIME convention)
    self.width = width * np.sqrt(training_data.shape[1])

    logging.info("Setting mode to %s", mode)
    logging.info("Credible interval set to %s", self.cred_int)

    if kernel == "shap" and data == "tabular":
        logging.info("Setting discretize_continuous to True, due to shapley sampling")
        # Bug fix: this previously rebound the *local* variable
        # `discretize_continuous`, which had no effect because _run_init reads
        # self.discretize_continuous — the forced discretization never applied.
        self.discretize_continuous = True

    self.training_data = training_data
    self._run_init(training_data)
88 |
+
|
89 |
+
def _run_init(self, training_data):
    """Initialize the underlying perturbation sampler.

    Both the "lime" and "shap" kernels reuse lime's LimeTabularExplainer for
    perturbation generation; only the attribute it is stored under differs
    (self.lime_info vs self.shap_info).

    Raises:
        NotImplementedError: If self.kernel is neither "lime" nor "shap".
    """
    if self.kernel == "lime":
        lime_tab_exp = lime_tabular.LimeTabularExplainer(training_data,
                                                         mode=self.mode,
                                                         categorical_features=self.categorical_features,
                                                         discretize_continuous=self.discretize_continuous)
        self.lime_info = lime_tab_exp
    elif self.kernel == "shap":
        # Discretization forcibly set to true for shap sampling on initialization
        shap_tab_exp = lime_tabular.LimeTabularExplainer(training_data,
                                                         mode=self.mode,
                                                         categorical_features=self.categorical_features,
                                                         discretize_continuous=self.discretize_continuous)
        self.shap_info = shap_tab_exp
    else:
        raise NotImplementedError
105 |
+
|
106 |
+
def _log_args(self, args):
|
107 |
+
"""Logs arguments to function."""
|
108 |
+
logging.info(args)
|
109 |
+
|
110 |
+
def _shap_tabular_perturb_n_samples(self,
                                    data,
                                    n_samples,
                                    max_coefs=None):
    """Generates n shap perturbations.

    Arguments:
        data: 1D instance being explained.
        n_samples: Number of perturbations to draw.
        max_coefs: Optional indices of the features allowed to vary; all
            other features are held at the original instance's values.
            Defaults to all features.

    Returns:
        (rdata, inverse): the binary coalition representation restricted to
        the max_coefs columns, and the corresponding rows in the original
        feature space.
    """
    if max_coefs is None:
        max_coefs = np.arange(data.shape[0])
    # Name-mangled access to lime's private perturbation routine
    pre_rdata, pre_inverse = self.shap_info._LimeTabularExplainer__data_inverse(data_row=data,
                                                                                num_samples=n_samples)
    rdata = pre_rdata[:, max_coefs]
    # Start from copies of the instance, then overwrite only the free features
    inverse = np.tile(data, (n_samples, 1))
    inverse[:, max_coefs] = pre_inverse[:, max_coefs]
    return rdata, inverse
123 |
+
|
124 |
+
def _lime_tabular_perturb_n_samples(self,
                                    data,
                                    n_samples):
    """Generates n_perturbations for LIME.

    Arguments:
        data: 1D instance being explained.
        n_samples: Number of perturbations to draw.

    Returns:
        (rdata, inverse, scaled_data, distances): the perturbed
        representation, the perturbations in the original feature space,
        the standardized perturbations, and each perturbation's euclidean
        distance to the instance (row 0 of scaled_data).
    """
    # Name-mangled access to lime's private perturbation routine
    rdata, inverse = self.lime_info._LimeTabularExplainer__data_inverse(data_row=data,
                                                                        num_samples=n_samples)
    # Standardize with the explainer's training-data statistics
    scaled_data = (rdata - self.lime_info.scaler.mean_) / self.lime_info.scaler.scale_
    # Distance of every perturbation to the instance itself (first row)
    distances = sklearn.metrics.pairwise_distances(
        scaled_data,
        scaled_data[0].reshape(1, -1),
        metric='euclidean'
    ).ravel()
    return rdata, inverse, scaled_data, distances
137 |
+
|
138 |
+
def _stack_tabular_return(self, existing_return, perturb_return):
|
139 |
+
"""Stacks data from new tabular return to existing return."""
|
140 |
+
if len(existing_return) == 0:
|
141 |
+
return perturb_return
|
142 |
+
new_return = []
|
143 |
+
for i, item in enumerate(existing_return):
|
144 |
+
new_return.append(np.concatenate((item, perturb_return[i]), axis=0))
|
145 |
+
return new_return
|
146 |
+
|
147 |
+
def _select_indices_from_data(self, perturb_return, indices, predictions):
|
148 |
+
"""Gets each element from the perturb return according to indices, then appends the predictions."""
|
149 |
+
# Previoulsy had this set to range(4)
|
150 |
+
temp = [perturb_return[i][indices] for i in range(len(perturb_return))]
|
151 |
+
temp.append(predictions)
|
152 |
+
return temp
|
153 |
+
|
154 |
+
def shap_tabular_focus_sample(self,
                              data,
                              classifier_f,
                              label,
                              n_samples,
                              focus_sample_batch_size,
                              focus_sample_initial_points,
                              to_consider=10_000,
                              tempurature=1e-2,
                              enumerate_initial=True):
    """Focus sample n_samples perturbations for the shap kernel.

    Draws an initial labeled batch (optionally including the enumerated
    single-feature coalitions), then repeatedly fits a Bayesian linear
    regression and labels the candidate perturbations with the highest
    predictive variance (softmax-weighted by `tempurature`) until
    n_samples perturbations have been collected.

    Arguments:
        data: 1D instance being explained.
        classifier_f: Model prediction function returning class probabilities.
        label: Index of the class being explained.
        n_samples: Total number of perturbations to collect.
        focus_sample_batch_size: Number of candidates labeled per round.
        focus_sample_initial_points: Size of the initial (unfocused) batch.
        to_consider: Size of the unlabeled candidate pool per round.
        tempurature: Softmax temperature for variance-based sampling.
        enumerate_initial: Whether to seed with single-feature coalitions.

    Returns:
        List [binary data, inverse data, predictions] of collected samples.
    """
    assert focus_sample_initial_points > 0, "Initial focusing sample points cannot be <= 0"
    current_n_perturbations = 0

    # Get 1's coalitions, if requested
    if enumerate_initial:
        enumerate_init_p = self._enumerate_initial_shap(data)
        current_n_perturbations += enumerate_init_p[0].shape[0]
    else:
        enumerate_init_p = None

    if self.verbose:
        pbar = tqdm(total=n_samples)
        pbar.update(current_n_perturbations)

    # Get initial points
    if current_n_perturbations < focus_sample_initial_points:
        initial_perturbations = self._shap_tabular_perturb_n_samples(data, focus_sample_initial_points - current_n_perturbations)

        if enumerate_init_p is not None:
            current_perturbations = self._stack_tabular_return(enumerate_init_p, initial_perturbations)
        else:
            current_perturbations = initial_perturbations

        current_n_perturbations += initial_perturbations[0].shape[0]
    else:
        # NOTE(review): in this branch `initial_perturbations` is never bound,
        # yet it is referenced below when self.verbose is True — that path
        # raises NameError. Confirm and fix.
        current_perturbations = enumerate_init_p

    current_perturbations = list(current_perturbations)

    # Store initial predictions
    current_perturbations.append(classifier_f(current_perturbations[SINVERSE])[:, label])
    if self.verbose:
        pbar.update(initial_perturbations[0].shape[0])

    while current_n_perturbations < n_samples:
        # Never overshoot the requested total
        current_batch_size = min(focus_sample_batch_size, n_samples - current_n_perturbations)

        # Init current BLR on everything labeled so far
        blr = BayesianLinearRegression(percent=self.cred_int)
        weights = self._get_shap_weights(current_perturbations[SDATA], current_perturbations[SDATA].shape[1])
        blr.fit(current_perturbations[SDATA], current_perturbations[-1], weights, compute_creds=False)

        candidate_perturbations = self._shap_tabular_perturb_n_samples(data, to_consider)
        # NOTE(review): the BLR is fit on the binary coalition data (SDATA)
        # but predictive variance is requested on SINVERSE (original feature
        # space). The shapes only line up when no feature selection restricts
        # SDATA — verify this is intended.
        _, var = blr.predict(candidate_perturbations[SINVERSE])

        # Get sampling weighting: softmax of variance / temperature so that
        # higher-uncertainty candidates are more likely to be labeled
        var /= tempurature
        exp_var = np.exp(var)
        all_exp = np.sum(exp_var)
        tempurature_scaled_weights = exp_var / all_exp

        # Get sampled indices (with replacement)
        least_confident_sample = np.random.choice(len(var), size=current_batch_size, p=tempurature_scaled_weights, replace=True)

        # Get predictions for the chosen candidates
        cy = classifier_f(candidate_perturbations[SINVERSE][least_confident_sample])[:, label]

        new_perturbations = self._select_indices_from_data(candidate_perturbations, least_confident_sample, cy)
        current_perturbations = self._stack_tabular_return(current_perturbations, new_perturbations)
        current_n_perturbations += new_perturbations[0].shape[0]

        if self.verbose:
            pbar.update(new_perturbations[0].shape[0])

    return current_perturbations
230 |
+
|
231 |
+
def lime_tabular_focus_sample(self,
                              data,
                              classifier_f,
                              label,
                              n_samples,
                              focus_sample_batch_size,
                              focus_sample_initial_points,
                              to_consider=10_000,
                              tempurature=5e-4,
                              existing_data=[]):
    """Focus sample n_samples perturbations for lime tabular.

    Same uncertainty-driven loop as the shap variant: fit a Bayesian linear
    regression on the labeled perturbations, then label the candidates with
    the highest predictive variance until n_samples are collected.

    Arguments:
        data: 1D instance being explained.
        classifier_f: Model prediction function returning class probabilities.
        label: Index of the class being explained.
        n_samples: Total number of perturbations to collect.
        focus_sample_batch_size: Number of candidates labeled per round.
        focus_sample_initial_points: Size of the initial (unfocused) batch.
        to_consider: Size of the unlabeled candidate pool per round.
        tempurature: Softmax temperature for variance-based sampling.
        existing_data: Previously generated perturbation tuple to seed with.
            NOTE(review): mutable default argument — harmless here since it
            is never mutated in place, but worth cleaning up.

    Returns:
        List [rdata, inverse, scaled data, distances, predictions].
    """
    current_n_perturbations = 0

    # Get initial focus sampling batch
    if len(existing_data) < focus_sample_initial_points:
        # If there's existing data, make sure we only sample up to existing_data points
        initial_perturbations = self._lime_tabular_perturb_n_samples(data, focus_sample_initial_points - len(existing_data))
        current_perturbations = self._stack_tabular_return(existing_data, initial_perturbations)
    else:
        current_perturbations = existing_data

    if self.verbose:
        pbar = tqdm(total=n_samples)

    current_perturbations = list(current_perturbations)
    # NOTE(review): `initial_perturbations` is unbound when
    # len(existing_data) >= focus_sample_initial_points, so the next line
    # raises NameError in that case — confirm and fix.
    current_n_perturbations += initial_perturbations[0].shape[0]

    # Store predictions on initial data
    current_perturbations.append(classifier_f(current_perturbations[LINVERSE])[:, label])
    if self.verbose:
        pbar.update(initial_perturbations[0].shape[0])

    # Sample up to n_samples
    while current_n_perturbations < n_samples:

        # If batch size would exceed n_samples, only sample enough to reach n_samples
        current_batch_size = min(focus_sample_batch_size, n_samples - current_n_perturbations)

        # Init current BLR
        blr = BayesianLinearRegression(percent=self.cred_int)
        # Get weights on current distances
        weights = self._lime_kernel(current_perturbations[LDISTANCES], self.width)
        # Fit blr on current perturbations & data
        blr.fit(current_perturbations[LDATA], current_perturbations[LY], weights)

        # Get set of perturbations to consider labeling
        candidate_perturbations = self._lime_tabular_perturb_n_samples(data, to_consider)
        _, var = blr.predict(candidate_perturbations[LDATA])

        # Reweight: softmax of variance / temperature
        var /= tempurature
        exp_var = np.exp(var)
        all_exp = np.sum(exp_var)
        tempurature_scaled_weights = exp_var / all_exp

        # Get sampled indices (without replacement, unlike the shap variant)
        least_confident_sample = np.random.choice(len(var), size=current_batch_size, p=tempurature_scaled_weights, replace=False)

        # Get predictions for the chosen candidates
        cy = classifier_f(candidate_perturbations[LINVERSE][least_confident_sample])[:, label]

        new_perturbations = self._select_indices_from_data(candidate_perturbations, least_confident_sample, cy)
        current_perturbations = self._stack_tabular_return(current_perturbations, new_perturbations)
        current_n_perturbations += new_perturbations[0].shape[0]

        if self.verbose:
            pbar.update(new_perturbations[0].shape[0])

    return current_perturbations
300 |
+
|
301 |
+
def _lime_kernel(self, d, kernel_width):
|
302 |
+
return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))
|
303 |
+
|
304 |
+
def _explain_bayes_lime(self,
                        data,
                        classifier_f,
                        label,
                        focus_sample,
                        cred_width,
                        n_samples,
                        max_n_samples,
                        focus_sample_batch_size,
                        focus_sample_initial_points,
                        ptg_initial_points,
                        to_consider):
    """Computes the bayeslime tabular explanations.

    Three sampling regimes:
      1. Fixed n_samples, no focused sampling: draw all perturbations at once.
      2. Focused sampling: iteratively label the most uncertain candidates.
      3. PTG: draw an initial batch, estimate how many perturbations are
         needed to hit the requested credible-interval width, then draw the
         remainder and refit.

    Arguments:
        data: 1D instance being explained.
        classifier_f: Model prediction function returning class probabilities.
        label: Index of the class being explained.
        focus_sample: Whether to use focused (uncertainty-driven) sampling.
        cred_width: Target credible-interval width for PTG.
        n_samples: Number of perturbations (None triggers the PTG path).
        max_n_samples: Unused in this method — TODO confirm.
        focus_sample_batch_size: Batch size for focused sampling.
        focus_sample_initial_points: Initial batch size for focused sampling.
        ptg_initial_points: Initial batch size for the PTG estimate.
        to_consider: Candidate pool size for focused sampling.

    Returns:
        dict with the perturbation data, labels, distances, the fitted
        BayesianLinearRegression ("blr"), its coefficients ("coef"), and
        "max_coefs" (always None here, kept for bayesshap consistency).
    """

    # Case where only n_samples is specified and not focused sampling
    if n_samples is not None and not focus_sample:
        logging.info("Generating bayeslime explanation with %s samples", n_samples)

        # Generate perturbations
        rdata, inverse, scaled_data, distances = self._lime_tabular_perturb_n_samples(data, n_samples)
        weights = self._lime_kernel(distances, self.width)
        y = classifier_f(inverse)[:, label]
        blr = BayesianLinearRegression(percent=self.cred_int)
        blr.fit(rdata, y, weights)
    # Focus sampling
    elif focus_sample:
        logging.info("Starting focused sampling")
        if n_samples:
            logging.info("n_samples preset, running focused sampling up to %s samples", n_samples)
            logging.info("using batch size %s with %s initial points", focus_sample_batch_size, focus_sample_initial_points)
            focused_sampling_output = self.lime_tabular_focus_sample(data,
                                                                     classifier_f,
                                                                     label,
                                                                     n_samples,
                                                                     focus_sample_batch_size,
                                                                     focus_sample_initial_points,
                                                                     to_consider=to_consider,
                                                                     existing_data=[])
            rdata = focused_sampling_output[LDATA]
            distances = focused_sampling_output[LDISTANCES]
            y = focused_sampling_output[LY]

            # Refit a final BLR on everything collected
            blr = BayesianLinearRegression(percent=self.cred_int)
            weights = self._lime_kernel(distances, self.width)
            blr.fit(rdata, y, weights)
        else:
            # Use ptg to get the number of samples, then focus sample
            # Note, this isn't used in the paper, this case currently isn't implemented
            raise NotImplementedError

    else:
        # PTG Step 1, get initial batch and fit a provisional model
        rdata, inverse, scaled_data, distances = self._lime_tabular_perturb_n_samples(data, ptg_initial_points)
        weights = self._lime_kernel(distances, self.width)
        y = classifier_f(inverse)[:, label]
        blr = BayesianLinearRegression(percent=self.cred_int)
        blr.fit(rdata, y, weights)

        # PTG Step 2, get additional points needed to reach cred_width
        n_needed = int(np.ceil(blr.get_ptg(cred_width)))
        if self.verbose:
            tqdm.write(f"Additional Number of perturbations needed is {n_needed}")
        ptg_rdata, ptg_inverse, ptg_scaled_data, ptg_distances = self._lime_tabular_perturb_n_samples(data, n_needed - ptg_initial_points)
        ptg_weights = self._lime_kernel(ptg_distances, self.width)

        rdata = np.concatenate((rdata, ptg_rdata), axis=0)
        inverse = np.concatenate((inverse, ptg_inverse), axis=0)
        scaled_data = np.concatenate((scaled_data, ptg_scaled_data), axis=0)
        distances = np.concatenate((distances, ptg_distances), axis=0)

        # Run final model on the combined sample
        ptgy = classifier_f(ptg_inverse)[:, label]
        y = np.concatenate((y, ptgy), axis=0)
        blr = BayesianLinearRegression(percent=self.cred_int)
        blr.fit(rdata, y, self._lime_kernel(distances, self.width))

    # Format output for returning
    output = {
        "data": rdata,
        "y": y,
        "distances": distances,
        "blr": blr,
        "coef": blr.coef_,
        "max_coefs": None  # Included for consistency purposes w/ bayesshap
    }

    return output
391 |
+
|
392 |
+
def _get_shap_weights(self, data, M):
    """Compute the SHAP kernel weight for each binary coalition row.

    Uses the standard kernel (M - 1) / (C(M, |z|) * |z| * (M - |z|)) where
    |z| is the number of active features in the row. Rows whose denominator
    is zero (all-zero or all-one coalitions) are clamped to weight 1.0 to
    keep the kernel numerically stable. Assumes `data` is binary.
    """
    coalition_sizes = np.count_nonzero(data, axis=1)
    weights = []
    for size in coalition_sizes:
        denominator = nCk(M, size) * size * (M - size)
        # Stabilize kernel: empty/full coalitions would divide by zero
        weights.append(1.0 if denominator == 0 else (M - 1) / denominator)
    return weights
405 |
+
|
406 |
+
def _enumerate_initial_shap(self, data, max_coefs=None):
    """Enumerate 1's for stability.

    Builds the single-feature coalitions (identity matrix rows) so every
    feature appears alone at least once, then maps them back to the original
    feature space through the explainer's discretizer.

    Arguments:
        data: 1D instance being explained (only its length is used here).
        max_coefs: Optional array of feature indices to restrict the
            coalitions to; when given, one coalition is built per index and
            the returned binary matrix is restricted to those columns.

    Returns:
        (binary coalition matrix, undiscretized rows).
    """
    if max_coefs is None:
        # One coalition per feature: the identity matrix
        data = np.eye(data.shape[0])
        inverse = self.shap_info.discretizer.undiscretize(data)
        return data, inverse
    else:
        # One coalition per selected feature index
        data = np.zeros((max_coefs.shape[0], data.shape[0]))
        for i in range(max_coefs.shape[0]):
            data[i, max_coefs[i]] = 1
        inverse = self.shap_info.discretizer.undiscretize(data)
        return data[:, max_coefs], inverse
418 |
+
|
419 |
+
def _explain_bayes_shap(self,
|
420 |
+
data,
|
421 |
+
classifier_f,
|
422 |
+
label,
|
423 |
+
focus_sample,
|
424 |
+
cred_width,
|
425 |
+
n_samples,
|
426 |
+
max_n_samples,
|
427 |
+
focus_sample_batch_size,
|
428 |
+
focus_sample_initial_points,
|
429 |
+
ptg_initial_points,
|
430 |
+
to_consider,
|
431 |
+
feature_select_num_points=1_000,
|
432 |
+
n_features=10,
|
433 |
+
l2=True,
|
434 |
+
enumerate_initial=True,
|
435 |
+
feature_selection=True,
|
436 |
+
max_coefs=None):
|
437 |
+
"""Computes the bayesshap tabular explanations."""
|
438 |
+
if feature_selection and max_coefs is None:
|
439 |
+
n_features = min(n_features, data.shape[0])
|
440 |
+
_, feature_select_inverse = self._shap_tabular_perturb_n_samples(data, feature_select_num_points)
|
441 |
+
lr = Ridge().fit(feature_select_inverse, classifier_f(feature_select_inverse)[:, label])
|
442 |
+
max_coefs = np.argsort(np.abs(lr.coef_))[-1 * n_features:]
|
443 |
+
elif feature_selection and max_coefs is not None:
|
444 |
+
pass
|
445 |
+
else:
|
446 |
+
max_coefs = None
|
447 |
+
|
448 |
+
# Case without focused sampling
|
449 |
+
if n_samples is not None and not focus_sample:
|
450 |
+
logging.info("Generating bayesshap explanation with %s samples", n_samples)
|
451 |
+
|
452 |
+
# Enumerate single coalitions, if requested
|
453 |
+
if enumerate_initial:
|
454 |
+
data_init, inverse_init = self._enumerate_initial_shap(data, max_coefs)
|
455 |
+
n_more = n_samples - inverse_init.shape[0]
|
456 |
+
else:
|
457 |
+
n_more = n_samples
|
458 |
+
|
459 |
+
rdata, inverse = self._shap_tabular_perturb_n_samples(data, n_more, max_coefs)
|
460 |
+
|
461 |
+
if enumerate_initial:
|
462 |
+
rdata = np.concatenate((data_init, rdata), axis=0)
|
463 |
+
inverse = np.concatenate((inverse_init, inverse), axis=0)
|
464 |
+
|
465 |
+
y = classifier_f(inverse)[:, label]
|
466 |
+
weights = self._get_shap_weights(rdata, M=rdata.shape[1])
|
467 |
+
|
468 |
+
blr = BayesianLinearRegression(percent=self.cred_int)
|
469 |
+
blr.fit(rdata, y, weights)
|
470 |
+
elif focus_sample:
|
471 |
+
if feature_selection:
|
472 |
+
raise NotImplementedError
|
473 |
+
|
474 |
+
logging.info("Starting focused sampling")
|
475 |
+
if n_samples:
|
476 |
+
logging.info("n_samples preset, running focused sampling up to %s samples", n_samples)
|
477 |
+
logging.info("using batch size %s with %s initial points", focus_sample_batch_size, focus_sample_initial_points)
|
478 |
+
focused_sampling_output = self.shap_tabular_focus_sample(data,
|
479 |
+
classifier_f,
|
480 |
+
label,
|
481 |
+
n_samples,
|
482 |
+
focus_sample_batch_size,
|
483 |
+
focus_sample_initial_points,
|
484 |
+
to_consider=to_consider,
|
485 |
+
enumerate_initial=enumerate_initial)
|
486 |
+
rdata = focused_sampling_output[SDATA]
|
487 |
+
y = focused_sampling_output[SY]
|
488 |
+
weights = self._get_shap_weights(rdata, rdata.shape[1])
|
489 |
+
blr = BayesianLinearRegression(percent=self.cred_int, l2=l2)
|
490 |
+
blr.fit(rdata, y, weights)
|
491 |
+
else:
|
492 |
+
# Use ptg to get the number of samples, then focus sample
|
493 |
+
# Note, this case isn't used in the paper and currently isn't implemented
|
494 |
+
raise NotImplementedError
|
495 |
+
else:
|
496 |
+
# Use PTG to get initial samples
|
497 |
+
|
498 |
+
# Enumerate intial points if requested
|
499 |
+
if enumerate_initial:
|
500 |
+
data_init, inverse_init = self._enumerate_initial_shap(data, max_coefs)
|
501 |
+
n_more = ptg_initial_points - inverse_init.shape[0]
|
502 |
+
else:
|
503 |
+
n_more = ptg_initial_points
|
504 |
+
|
505 |
+
# Perturb using initial samples
|
506 |
+
rdata, inverse = self._shap_tabular_perturb_n_samples(data, n_more, max_coefs)
|
507 |
+
if enumerate_initial:
|
508 |
+
rdata = np.concatenate((data_init, rdata), axis=0)
|
509 |
+
inverse = np.concatenate((inverse_init, inverse), axis=0)
|
510 |
+
|
511 |
+
# Get labels
|
512 |
+
y = classifier_f(inverse)[:, label]
|
513 |
+
|
514 |
+
# Fit BLR
|
515 |
+
weights = self._get_shap_weights(rdata, M=rdata.shape[1])
|
516 |
+
blr = BayesianLinearRegression(percent=self.cred_int, l2=l2)
|
517 |
+
blr.fit(rdata, y, weights)
|
518 |
+
|
519 |
+
# Compute PTG number needed
|
520 |
+
n_needed = int(np.ceil(blr.get_ptg(cred_width)))
|
521 |
+
ptg_rdata, ptg_inverse = self._shap_tabular_perturb_n_samples(data,
|
522 |
+
n_needed - ptg_initial_points,
|
523 |
+
max_coefs)
|
524 |
+
|
525 |
+
if self.verbose:
|
526 |
+
tqdm.write(f"{n_needed} more samples needed")
|
527 |
+
|
528 |
+
rdata = np.concatenate((rdata, ptg_rdata), axis=0)
|
529 |
+
inverse = np.concatenate((inverse, ptg_inverse), axis=0)
|
530 |
+
ptgy = classifier_f(ptg_inverse)[:, label]
|
531 |
+
weights = self._get_shap_weights(rdata, M=rdata.shape[1])
|
532 |
+
|
533 |
+
# Run final model
|
534 |
+
ptgy = classifier_f(ptg_inverse)[:, label]
|
535 |
+
y = np.concatenate((y, ptgy), axis=0)
|
536 |
+
blr = BayesianLinearRegression(percent=self.cred_int, l2=l2)
|
537 |
+
blr.fit(rdata, y, weights)
|
538 |
+
|
539 |
+
# Format output for returning
|
540 |
+
output = {
|
541 |
+
"data": rdata,
|
542 |
+
"y": y,
|
543 |
+
"distances": weights,
|
544 |
+
"blr": blr,
|
545 |
+
"coef": blr.coef_,
|
546 |
+
"max_coefs": max_coefs
|
547 |
+
}
|
548 |
+
|
549 |
+
return output
|
550 |
+
|
551 |
+
def explain(self,
|
552 |
+
data,
|
553 |
+
classifier_f,
|
554 |
+
label,
|
555 |
+
cred_width=1e-2,
|
556 |
+
focus_sample=True,
|
557 |
+
n_samples=None,
|
558 |
+
max_n_samples=10_000,
|
559 |
+
focus_sample_batch_size=2_500,
|
560 |
+
focus_sample_initial_points=100,
|
561 |
+
ptg_initial_points=200,
|
562 |
+
to_consider=10_000,
|
563 |
+
feature_selection=True,
|
564 |
+
n_features=15,
|
565 |
+
tag=None,
|
566 |
+
only_coef=False,
|
567 |
+
only_blr=False,
|
568 |
+
enumerate_initial=True,
|
569 |
+
max_coefs=None,
|
570 |
+
l2=True):
|
571 |
+
"""Explain an instance.
|
572 |
+
|
573 |
+
As opposed to other model agnostic explanations, the bayes explanations
|
574 |
+
accept a credible interval width instead of a number of perturbations
|
575 |
+
value.
|
576 |
+
|
577 |
+
If the credible interval is set to 95% (as is the default), the bayesian
|
578 |
+
explanations will generate feature importances that are +/- width/2
|
579 |
+
95% of the time.
|
580 |
+
|
581 |
+
|
582 |
+
Arguments:
|
583 |
+
data: The data instance to explain
|
584 |
+
classifier_f: The classification function. This function should return
|
585 |
+
probabilities for each label, where if there are M labels
|
586 |
+
and N instances, the output is of shape (N, M).
|
587 |
+
label: The label index to explain.
|
588 |
+
cred_width: The width of the credible interval of the resulting explanation. Note,
|
589 |
+
this serves as a upper bound in the implementation, the final credible
|
590 |
+
intervals may be tighter, because PTG is a bit approximate. Also, be
|
591 |
+
aware that for kernelshap, if we can compute the kernelshap values exactly
|
592 |
+
by enumerating all the coalitions.
|
593 |
+
focus_sample: Whether to use uncertainty sampling.
|
594 |
+
n_samples: If specified, n_samples with override the width setting feature
|
595 |
+
and compute the explanation with n_samples.
|
596 |
+
max_n_samples: The maximum number of samples to use. If the width is set to
|
597 |
+
a very small value and many samples are required, this serves
|
598 |
+
as a point to stop sampling.
|
599 |
+
focus_sample_batch_size: The batch size of focus sampling.
|
600 |
+
focus_sample_initial_points: The number of perturbations to collect before starting
|
601 |
+
focused sampling.
|
602 |
+
ptg_initial_points: The number perturbations to collect before computing the ptg estimate.
|
603 |
+
to_consider: The number of perturbations to consider in focused sampling.
|
604 |
+
feature_selection: Whether to do feature selection using Ridge regression. Note, currently
|
605 |
+
only implemented for BayesSHAP.
|
606 |
+
n_features: The number of features to use in feature selection.
|
607 |
+
tag: A tag to add the explanation.
|
608 |
+
only_coef: Only return the explanation means.
|
609 |
+
only_blr: Only return the bayesian regression object.
|
610 |
+
enumerate_initial: Whether to enumerate a set of initial shap coalitions.
|
611 |
+
l2: Whether to fit with l2 regression. Turning off the l2 regression can be useful for the shapley value estimation.
|
612 |
+
Returns:
|
613 |
+
explanation: The resulting feature importances, credible intervals, and bayes regression
|
614 |
+
object.
|
615 |
+
"""
|
616 |
+
assert isinstance(data, np.ndarray), "Data must be numpy array. Note, this means that classifier_f \
|
617 |
+
must accept numpy arrays."
|
618 |
+
self._log_args(locals())
|
619 |
+
|
620 |
+
if self.kernel == "lime" and self.data in ["tabular", "image"]:
|
621 |
+
output = self._explain_bayes_lime(data,
|
622 |
+
classifier_f,
|
623 |
+
label,
|
624 |
+
focus_sample,
|
625 |
+
cred_width,
|
626 |
+
n_samples,
|
627 |
+
max_n_samples,
|
628 |
+
focus_sample_batch_size,
|
629 |
+
focus_sample_initial_points,
|
630 |
+
ptg_initial_points,
|
631 |
+
to_consider)
|
632 |
+
elif self.kernel == "shap" and self.data in ["tabular", "image"]:
|
633 |
+
output = self._explain_bayes_shap(data,
|
634 |
+
classifier_f,
|
635 |
+
label,
|
636 |
+
focus_sample,
|
637 |
+
cred_width,
|
638 |
+
n_samples,
|
639 |
+
max_n_samples,
|
640 |
+
focus_sample_batch_size,
|
641 |
+
focus_sample_initial_points,
|
642 |
+
ptg_initial_points,
|
643 |
+
to_consider,
|
644 |
+
feature_selection=feature_selection,
|
645 |
+
n_features=n_features,
|
646 |
+
enumerate_initial=enumerate_initial,
|
647 |
+
max_coefs=max_coefs,
|
648 |
+
l2=l2)
|
649 |
+
else:
|
650 |
+
pass
|
651 |
+
|
652 |
+
output['tag'] = tag
|
653 |
+
|
654 |
+
if only_coef:
|
655 |
+
return output['coef']
|
656 |
+
|
657 |
+
if only_blr:
|
658 |
+
return output['blr']
|
659 |
+
|
660 |
+
return output
|
661 |
+
|
662 |
+
|
663 |
+
def nCk(n, r):
    """Return the binomial coefficient "n choose r" as an exact integer.

    Fix: the original used float division (``/``), which loses precision and
    can overflow for large coefficients; the quotient is always an exact
    integer, so floor division is both correct and safe.

    From: https://stackoverflow.com/questions/4941753/is-there-a-math-ncr-function-in-python
    """
    r = min(r, n - r)
    numer = reduce(op.mul, range(n, n - r, -1), 1)
    denom = reduce(op.mul, range(1, r + 1), 1)
    return numer // denom
|
671 |
+
|
672 |
+
|
673 |
+
def do_exp(args):
    """Explain one row of a dataset; worker function used by explain_many.

    ``args`` is a packed tuple:
        (row index, full data array, explainer-init kwargs, explain kwargs,
         optional per-row labels, optional per-row feature subsets,
         optional namespace of extra flags)
    """
    idx, all_data, init_kwargs, exp_kwargs, labels, max_coefs, pass_args = args
    instance = all_data[idx]
    # Per-row label if provided, otherwise the shared label from exp_kwargs.
    label = labels[idx] if labels is not None else exp_kwargs['label']

    if pass_args is not None and pass_args.balance_background_dataset:
        # Background of exactly the instance plus an all-zeros row.
        init_kwargs['training_data'] = np.concatenate(
            (instance[None, :], np.zeros((1, instance.shape[0]))), axis=0)

    explainer = BayesLocalExplanations(**init_kwargs)
    exp_kwargs['tag'] = idx
    exp_kwargs['label'] = label
    if max_coefs is not None:
        exp_kwargs['max_coefs'] = max_coefs[idx]
    # Deep-copy so the returned explanation is detached from explainer state.
    return deepcopy(explainer.explain(instance, **exp_kwargs))
|
692 |
+
|
693 |
+
|
694 |
+
def explain_many(all_data, init_kwargs, exp_kwargs, pool_size=1, verbose=False, labels=None, max_coefs=None, args=None):
    """Explain every row of ``all_data`` in parallel using a process pool.

    Arguments:
        all_data: array of instances, one explanation produced per row.
        init_kwargs: kwargs for constructing each BayesLocalExplanations.
        exp_kwargs: kwargs forwarded to each ``explain`` call.
        pool_size: number of worker processes.
        verbose: show a tqdm progress bar while results stream in.
        labels: optional per-row label indices.
        max_coefs: optional per-row feature subsets.
        args: optional namespace of extra flags passed through to do_exp.
    Returns:
        list of explanation outputs, one per row, in row order.
    """
    work_items = [
        (row, all_data, init_kwargs, exp_kwargs, labels, max_coefs, args)
        for row in range(all_data.shape[0])
    ]
    with Pool(pool_size) as pool:
        if verbose:
            # imap streams results so the progress bar advances per row.
            return list(tqdm(pool.imap(do_exp, work_items)))
        return pool.map(do_exp, work_items)
|
bayes/models.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Routines that implement processing data & getting models.
|
2 |
+
|
3 |
+
This file includes various routines for processing & acquiring models, for
|
4 |
+
later use in the code. The table data preprocessing is straightforward. We
|
5 |
+
first apply scaling to the data and fit a random forest classifier.
|
6 |
+
|
7 |
+
The processing of the image data is a bit more complex. To simplify the construction
|
8 |
+
of the explanations, the explanations don't accept images. Instead, for image explanations,
|
9 |
+
it is necessary to define a function that accepts an array of 0's and 1's corresponding to
|
10 |
+
segments for a particular image being either excluded or included respectively. The explanation
|
11 |
+
is performed on this array.
|
12 |
+
"""
|
13 |
+
import numpy as np
|
14 |
+
from copy import deepcopy
|
15 |
+
|
16 |
+
from sklearn.ensemble import RandomForestClassifier
|
17 |
+
from sklearn.preprocessing import StandardScaler
|
18 |
+
from sklearn.model_selection import train_test_split
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torchvision import models, transforms
|
22 |
+
|
23 |
+
from data.mnist.mnist_model import Net
|
24 |
+
|
25 |
+
def get_xtrain(segs):
    """Build the two-row mock training set used by the image explanations.

    The explainer only needs a dataset spanning the coalition space: one row
    with every segment switched on and one with every segment switched off.
    Perturbations drawn from it are fed into the wrapped model.

    Arguments:
        segs: the current segments array
    Returns:
        (2, n_segments) array: first row all ones, second row all zeros.
    """
    num_segments = np.unique(segs).shape[0]
    all_on = np.ones((1, num_segments))
    all_off = np.zeros((1, num_segments))
    return np.concatenate((all_on, all_off), axis=0)
|
39 |
+
|
40 |
+
def process_imagenet_get_model(data):
    """Gets wrapped imagenet model.

    Wraps a pretrained VGG16 so the explainer can query it with binary
    segment masks: the returned factory (under the "model" key) closes over
    one image and its segment map and produces a function that maps mask
    arrays to softmax class probabilities.

    Arguments:
        data: dict with keys 'X' (images), 'y' (labels), 'segments'
              (per-image segment maps).
    Returns:
        dict with the wrapped-model factory, the test images, integer labels,
        segment maps, and the first instance's label.
    """

    # Get the vgg16 model, used in the experiments
    model = models.vgg16(pretrained=True)
    model.eval()
    # model.cuda()

    xtest = data['X']
    ytest = data['y'].astype(int)
    xtest_segs = data['segments']

    softmax = torch.nn.Softmax(dim=1)

    # Transforms: standard ImageNet normalization constants.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transf = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    # NOTE(review): t_xtest is computed but never used below — possibly a
    # leftover sanity check; confirm before removing.
    t_xtest = transf(xtest[0])[None, :]#.cuda()

    # Define the wrapped model
    def get_wrapped_model(instance, segments, background=0, batch_size=64):
        # Factory: closes over one image (`instance`) and its `segments` map.
        def wrapped_model(data):
            # `data` is an (n_masks, n_segments) array of 0/1 values; each row
            # selects which segments stay visible.
            perturbed_images = []
            for d in data:
                perturbed_image = deepcopy(instance)
                for i, is_on in enumerate(d):
                    if is_on == 0:
                        # Paint out segment i across all three channels.
                        # NOTE(review): indexing assumes HWC layout with the
                        # channel as the last axis — confirm against `instance`.
                        perturbed_image[segments==i, 0] = background
                        perturbed_image[segments==i, 1] = background
                        perturbed_image[segments==i, 2] = background
                perturbed_images.append(transf(perturbed_image)[None, :])
            # NOTE(review): transf returns torch tensors; np.concatenate on a
            # list of tensors relies on implicit numpy conversion — verify.
            perturbed_images = torch.from_numpy(np.concatenate(perturbed_images, axis=0)).float()
            predictions = []
            # Batch the forward passes to bound memory use.
            for q in range(0, perturbed_images.shape[0], batch_size):
                predictions.append(softmax(model(perturbed_images[q:q+batch_size])).cpu().detach().numpy())
            predictions = np.concatenate(predictions, axis=0)
            return predictions
        return wrapped_model

    output = {
        "model": get_wrapped_model,
        "xtest": xtest,
        "ytest": ytest,
        "xtest_segs": xtest_segs,
        "label": data['y'][0]
    }

    return output
|
93 |
+
|
94 |
+
def process_mnist_get_model(data):
    """Gets wrapped mnist model.

    Loads the MNIST CNN checkpoint and wraps it so the explainer can query it
    with binary segment masks: the returned factory (under the "model" key)
    closes over one image and its segment map and produces a function that
    maps mask arrays to softmax class probabilities.

    Arguments:
        data: dict with keys 'X' (images), 'y' (labels), 'segments'
              (per-image segment maps).
    Returns:
        dict with the wrapped-model factory, the test images, integer labels,
        segment maps, and the first instance's label.

    NOTE(review): requires a CUDA device and the checkpoint at
    ../data/mnist/mnist_cnn.pt — confirm both exist in the runtime env.
    """
    xtest = data['X']
    ytest = data['y'].astype(int)
    xtest_segs = data['segments']

    model = Net()
    model.load_state_dict(torch.load("../data/mnist/mnist_cnn.pt"))
    model.eval()
    model.cuda()

    softmax = torch.nn.Softmax(dim=1)

    def get_wrapped_model(instance, segments, background=-0.4242, batch_size=100):
        # background=-0.4242 presumably corresponds to a normalized black
        # MNIST pixel ((0 - 0.1307) / 0.3081) — confirm against preprocessing.
        def wrapped_model(data):
            perturbed_images = []
            data = torch.from_numpy(data).float().cuda()
            for d in data:
                perturbed_image = deepcopy(instance)
                for i, is_on in enumerate(d):
                    if is_on == 0:
                        # Paint out segment i. (Fix: removed an unused
                        # `a = segments==i` temporary that was recomputed on
                        # every masked segment for no effect.)
                        perturbed_image[0, segments[0]==i] = background
                perturbed_images.append(perturbed_image[:, None])
            perturbed_images = torch.from_numpy(np.concatenate(perturbed_images, axis=0)).float().cuda()

            # Batch predictions if necessary
            if perturbed_images.shape[0] > batch_size:
                predictions = []
                for q in range(0, perturbed_images.shape[0], batch_size):
                    predictions.append(softmax(model(perturbed_images[q:q+batch_size])).cpu().detach().numpy())
                predictions = np.concatenate(predictions, axis=0)
            else:
                predictions = softmax(model(perturbed_images)).cpu().detach().numpy()
            return np.array(predictions)
        return wrapped_model

    output = {
        "model": get_wrapped_model,
        "xtest": xtest,
        "ytest": ytest,
        "xtest_segs": xtest_segs,
        "label": data['y'][0],
    }

    return output
|
139 |
+
|
140 |
+
def process_tabular_data_get_model(data):
    """Scale a tabular dataset and train a random forest classifier on it.

    Arguments:
        data: dict with keys 'X' (feature matrix) and 'y' (targets).
    Returns:
        dict with the fitted model, the scaled train/test splits, the label
        index to explain (1), and the model's held-out accuracy.
    """
    features = data['X']
    targets = data['y']

    # 80/20 split, then standardize using statistics from the training fold
    # only (no leakage from the test fold).
    xtrain, xtest, ytrain, ytest = train_test_split(features, targets, test_size=0.2)
    scaler = StandardScaler().fit(xtrain)
    xtrain = scaler.transform(xtrain)
    xtest = scaler.transform(xtest)

    forest = RandomForestClassifier(n_estimators=100).fit(xtrain, ytrain)

    output = {
        "model": forest,
        "xtrain": xtrain,
        "xtest": xtest,
        "ytrain": ytrain,
        "ytest": ytest,
        "label": 1,
        "model_score": forest.score(xtest, ytest)
    }

    print(f"Model Score: {output['model_score']}")

    return output
|
bayes/regression.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Bayesian regression.
|
2 |
+
|
3 |
+
A class the implements the Bayesian Regression.
|
4 |
+
"""
|
5 |
+
import operator as op
|
6 |
+
from functools import reduce
|
7 |
+
import copy
|
8 |
+
import collections
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
from scipy.stats import invgamma
|
12 |
+
from scipy.stats import multivariate_normal
|
13 |
+
|
14 |
+
class BayesianLinearRegression:
    """Weighted Bayesian linear regression with a closed-form posterior.

    Fits a weighted linear model (optionally with an l2/identity prior term)
    and exposes the posterior over coefficients: point estimates (``coef_``,
    ``intercept_``), credible intervals (``creds``), posterior samples, a
    predictive distribution, and the PTG sample-size estimate.
    """
    def __init__(self, percent=95, l2=True, prior=None):
        """
        Arguments:
            percent: credible-interval level, e.g. 95 for a 95% interval.
            l2: if True, add an identity (ridge) term to the normal equations.
            prior: unsupported; must be None (uninformative prior only).
        """
        if prior is not None:
            # NOTE(review): NameError is an unusual choice here; preserved in
            # case existing callers catch it.
            raise NameError("Currently only support uninformative prior, set to None plz.")

        self.percent = percent
        self.l2 = l2

    def fit(self, xtrain, ytrain, sample_weight, compute_creds=True):
        """
        Fit the bayesian linear regression.

        Arguments:
            xtrain: (n, d) training data
            ytrain: (n,) training labels
            sample_weight: (n,) weights for fitting the regression
            compute_creds: whether to compute credible intervals after fitting
        Returns:
            self
        """
        # store weights (as an array so they broadcast below)
        weights = np.asarray(sample_weight, dtype=float)

        # add intercept column
        xtrain = np.concatenate((np.ones(xtrain.shape[0])[:, None], xtrain), axis=1)

        # Fix: compute X^T W X via row scaling instead of materializing the
        # dense (n, n) diagonal weight matrix — mathematically identical but
        # O(n*d) memory instead of O(n^2).
        xw = xtrain * weights[:, None]

        if self.l2:
            V_Phi = np.linalg.inv(xw.transpose().dot(xtrain) + np.eye(xtrain.shape[1]))
        else:
            V_Phi = np.linalg.inv(xw.transpose().dot(xtrain))

        # Posterior mean of the coefficients (weighted least squares solution).
        Phi_hat = V_Phi.dot(xw.transpose()).dot(ytrain)

        N = xtrain.shape[0]
        Y_m_Phi_hat = ytrain - xtrain.dot(Phi_hat)

        # Posterior scale: weighted residual sum of squares plus the
        # coefficient norm term.
        s_2 = (1.0 / N) * (np.sum(weights * Y_m_Phi_hat ** 2)
                           + Phi_hat.transpose().dot(Phi_hat))

        self.score = s_2

        self.s_2 = s_2
        self.N = N
        self.V_Phi = V_Phi
        self.Phi_hat = Phi_hat
        self.coef_ = Phi_hat[1:]
        self.intercept_ = Phi_hat[0]
        self.weights = weights

        if compute_creds:
            self.creds = self.get_creds(percent=self.percent)
        else:
            self.creds = "NA"

        self.crit_params = {
            "s_2": self.s_2,
            "N": self.N,
            "V_Phi": self.V_Phi,
            "Phi_hat": self.Phi_hat,
            "creds": self.creds
        }

        return self

    def predict(self, data):
        """
        The predictive distribution.

        Arguments:
            data: (m, d) data to predict
        Returns:
            (mean, std): the predictive mean and per-point predictive
            standard deviation.
        """
        q_1 = np.eye(data.shape[0])
        data_ones = np.concatenate((np.ones(data.shape[0])[:, None], data), axis=1)

        # Get response (posterior-mean prediction)
        response = np.matmul(data, self.coef_)
        response += self.intercept_

        # Compute predictive variance: s^2 * (I + X V_Phi X^T)
        temp = np.matmul(data_ones, self.V_Phi)
        mat = np.matmul(temp, data_ones.transpose())
        var = self.s_2 * (q_1 + mat)
        diag = np.diagonal(var)

        return response, np.sqrt(diag)

    def get_ptg(self, desired_width):
        """
        Estimate the number of perturbations needed to reach the desired
        credible-interval width (the PTG estimate).
        """
        # Width of a 95% interval around a normal: 2 * 1.96 * sd.
        cert = (desired_width / 1.96) ** 2
        S = self.coef_.shape[0] * self.s_2
        T = np.mean(self.weights)
        return 4 * S / (self.coef_.shape[0] * T * cert)

    def get_creds(self, percent=95, n_samples=10_000, get_intercept=False):
        """
        Get the credible intervals.

        Arguments:
            percent: the percent cutoff for the credible interval, i.e., 95 is 95% credible interval
            n_samples: the number of samples to compute the credible interval
            get_intercept: whether to include the intercept in the credible interval
        Returns:
            per-coefficient half-widths of the credible interval.
        """
        samples = self.draw_posterior_samples(n_samples, get_intercept=get_intercept)
        creds = np.percentile(np.abs(samples - (self.Phi_hat if get_intercept else self.coef_)),
                              percent,
                              axis=0)
        return creds

    def draw_posterior_samples(self, num_samples, get_intercept=False):
        """
        Sample from the posterior.

        Arguments:
            num_samples: number of samples to draw from the posterior
            get_intercept: whether to include the intercept
        Returns:
            (num_samples, d) or (num_samples, d+1) array of coefficient draws.
        """
        # Normal-inverse-gamma posterior: draw sigma^2, then coefficients
        # conditional on each sigma^2.
        sigma_2 = invgamma.rvs(self.N / 2, scale=(self.N * self.s_2) / 2, size=num_samples)

        phi_samples = []
        for sig in sigma_2:
            sample = multivariate_normal.rvs(mean=self.Phi_hat,
                                             cov=self.V_Phi * sig,
                                             size=1)
            phi_samples.append(sample)

        phi_samples = np.vstack(phi_samples)

        if get_intercept:
            return phi_samples
        else:
            return phi_samples[:, 1:]
|
requirements.txt
CHANGED
@@ -6,7 +6,6 @@ astor
|
|
6 |
astunparse
|
7 |
attrs
|
8 |
backcall
|
9 |
-
bayes
|
10 |
beautifulsoup4
|
11 |
BHClustering
|
12 |
bleach
|
|
|
6 |
astunparse
|
7 |
attrs
|
8 |
backcall
|
|
|
9 |
beautifulsoup4
|
10 |
BHClustering
|
11 |
bleach
|