Commit 4004daa • Lev McKinney committed
1 parent: c35da92

fixed several bugs in app.py
Files changed:
- .dockerignore (+1 -1)
- README.md (+0 -1)
- app.py (+7 -8)
.dockerignore CHANGED
@@ -1,2 +1,2 @@
 lens
-.git
+.git
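The removed and re-added .git entries render identically here; whatever actually changed is invisible in this view, presumably trailing whitespace or a missing final newline.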
README.md CHANGED
@@ -3,7 +3,6 @@ title: Tuned Lens
 emoji: π
 colorFrom: pink
 colorTo: blue
-port: 7860
 sdk: docker
 pinned: false
 license: mit
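Dropping port: 7860 is consistent with the app.py change at the bottom of this commit: as far as I can tell, port is not a key the Spaces config recognizes (Docker Spaces read app_port, which already defaults to 7860), so the port is instead pinned explicitly in demo.launch(...).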
app.py CHANGED
@@ -7,7 +7,7 @@ from plotly import graph_objects as go
 
 device = torch.device("cpu")
 print(f"Using device {device} for inference")
-model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")
+model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped", torch_dtype="auto")
 model = model.to(device)
 tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
 tuned_lens = TunedLens.from_model_and_pretrained(
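The one-line fix in this hunk adds torch_dtype="auto", which asks transformers to load the weights in the dtype recorded with the checkpoint instead of upcasting everything to float32. A minimal sketch of the effect (the model name is from the diff; the float16 assumption is mine, based on how such checkpoints are commonly stored):

```python
from transformers import AutoModelForCausalLM

# Default behaviour: weights are loaded as float32 regardless of
# how they were saved on the Hub.
model_fp32 = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")

# torch_dtype="auto": keep the dtype stored with the checkpoint,
# which roughly halves memory if the weights were saved in float16.
model_auto = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-410m-deduped", torch_dtype="auto"
)

print(model_fp32.dtype, model_auto.dtype)
```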
@@ -29,19 +29,19 @@ statistic_options_dict = {
 
 
 def make_plot(lens, text, statistic, token_cutoff):
-    input_ids = tokenizer.encode(text
+    input_ids = tokenizer.encode(text)
     input_ids = [tokenizer.bos_token_id] + input_ids
     targets = input_ids[1:] + [tokenizer.eos_token_id]
 
-    if len(input_ids
+    if len(input_ids) == 1:
         return go.Figure(layout=dict(title="Please enter some text."))
 
     if token_cutoff < 1:
         return go.Figure(layout=dict(title="Please provide valid token cut off."))
 
-    start_pos=max(len(input_ids
+    start_pos=max(len(input_ids) - token_cutoff, 0)
     pred_traj = PredictionTrajectory.from_lens_and_model(
-        lens=lens,
+        lens=lens_options_dict[lens],
        model=model,
        input_ids=input_ids,
        tokenizer=tokenizer,
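Two of these fixes share a cause: Gradio passes a dropdown's selected label (a plain string) into the callback, not the underlying Python object, so the label has to be looked up in a dict — lens_options_dict here, and statistic_options_dict in the next hunk. A minimal sketch of the pattern, with hypothetical labels and values (the real dict contents are not visible in this diff):

```python
import gradio as gr

# Hypothetical label -> object mapping; in app.py the values would be
# lens instances such as a TunedLens.
lens_options_dict = {"Tuned Lens": "tuned-lens-object", "Logit Lens": "logit-lens-object"}

def make_plot(lens):
    # `lens` arrives as the selected label string, e.g. "Tuned Lens",
    # so index into the dict to recover the actual object.
    return f"Plotting with {lens_options_dict[lens]}"

with gr.Blocks() as demo:
    lens_options = gr.Dropdown(list(lens_options_dict), value="Tuned Lens", label="Lens")
    out = gr.Textbox()
    lens_options.change(make_plot, [lens_options], out)
```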
@@ -49,7 +49,7 @@ def make_plot(lens, text, statistic, token_cutoff):
         start_pos=start_pos,
     )
 
-    return getattr(pred_traj, statistic)().figure(
+    return getattr(pred_traj, statistic_options_dict[statistic])().figure(
         title=f"{lens.__class__.__name__} ({model.name_or_path}) {statistic}",
     )
 
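The statistic fix is the same lookup plus a getattr dispatch: the dropdown label is translated to a PredictionTrajectory method name, fetched off the object, and called. A self-contained sketch of that dispatch (class and method names here are illustrative, not from the source):

```python
class Trajectory:
    """Stand-in for PredictionTrajectory with two plottable statistics."""

    def entropy(self):
        return "entropy figure"

    def cross_entropy(self):
        return "cross-entropy figure"

# Maps human-readable dropdown labels to method names.
statistic_options_dict = {"Entropy": "entropy", "Cross entropy": "cross_entropy"}

traj = Trajectory()
label = "Entropy"
figure = getattr(traj, statistic_options_dict[label])()
print(figure)  # -> "entropy figure"
```

One loose end the commit appears to leave behind: the title still evaluates lens.__class__.__name__, and since lens is now the dropdown's label string rather than a lens object, that expression should render as "str".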
@@ -114,5 +114,4 @@ with gr.Blocks() as demo:
     demo.load(make_plot, [lens_options, text, statistic, token_cutoff], plot)
 
 if __name__ == "__main__":
-    demo.launch()
-
+    demo.launch(server_name="0.0.0.0", server_port=7860)
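This last fix pairs with the README change: the port now lives in code rather than in the Space config. Inside a Docker-based Space the Gradio server must bind 0.0.0.0 (all interfaces) so that requests forwarded into the container can reach it, and it must listen on the port the Space expects, 7860 by default. A minimal standalone sketch, assuming nothing beyond a bare Gradio app:

```python
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("Hello from inside the container")

if __name__ == "__main__":
    # Gradio's default host (127.0.0.1) is unreachable from outside the
    # container; 0.0.0.0 accepts the connections the platform forwards in.
    demo.launch(server_name="0.0.0.0", server_port=7860)
```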