leandro commited on
Commit
dc9a7be
1 Parent(s): 1c022e5

add examples

Browse files
Files changed (2) hide show
  1. app.py +7 -4
  2. examples.json +14 -7
app.py CHANGED
@@ -16,7 +16,7 @@ def load_model(model_ckpt):
16
  def load_examples():
17
  with open("examples.json", "r") as f:
18
  examples = json.load(f)
19
- return dict([(x["name"], x["value"]) for x in examples])
20
 
21
  st.set_page_config(page_icon=':parrot:', layout="wide")
22
 
@@ -28,6 +28,8 @@ model_ckpt = "lvwerra/codeparrot"
28
  tokenizer = load_tokenizer(model_ckpt)
29
  model = load_model(model_ckpt)
30
  examples = load_examples()
 
 
31
  set_seed(42)
32
  gen_kwargs = {}
33
 
@@ -36,11 +38,12 @@ st.markdown('##')
36
 
37
  pipe = pipeline('text-generation', model=model, tokenizer=tokenizer)
38
  st.sidebar.header("Examples:")
39
- selected_example = st.sidebar.selectbox("Select one of the following examples:", examples.keys())
40
- example_text = examples[selected_example]
 
41
  st.sidebar.header("Generation settings:")
42
  gen_kwargs["do_sample"] = st.sidebar.radio("Decoding strategy", ["Greedy", "Sample"]) == "Sample"
43
- gen_kwargs["max_new_tokens"] = st.sidebar.slider("Number of tokens to generate", value=32, min_value=8, step=8, max_value=256)
44
  if gen_kwargs["do_sample"]:
45
  gen_kwargs["temperature"] = st.sidebar.slider("Temperature", value = 0.2, min_value = 0.0, max_value=2.0, step=0.05)
46
  gen_kwargs["top_k"] = st.sidebar.slider("Top-k", min_value = 0, max_value=100, value = 0)
 
16
  def load_examples():
17
  with open("examples.json", "r") as f:
18
  examples = json.load(f)
19
+ return examples
20
 
21
  st.set_page_config(page_icon=':parrot:', layout="wide")
22
 
 
28
  tokenizer = load_tokenizer(model_ckpt)
29
  model = load_model(model_ckpt)
30
  examples = load_examples()
31
+ example_names = [example["name"] for example in examples]
32
+ name2id = dict([(name, i) for i, name in enumerate(example_names)])
33
  set_seed(42)
34
  gen_kwargs = {}
35
 
 
38
 
39
  pipe = pipeline('text-generation', model=model, tokenizer=tokenizer)
40
  st.sidebar.header("Examples:")
41
+ selected_example = st.sidebar.selectbox("Select one of the following examples:", example_names)
42
+ example_text = examples[name2id[selected_example]]["value"]
43
+ default_length = examples[name2id[selected_example]]["length"]
44
  st.sidebar.header("Generation settings:")
45
  gen_kwargs["do_sample"] = st.sidebar.radio("Decoding strategy", ["Greedy", "Sample"]) == "Sample"
46
+ gen_kwargs["max_new_tokens"] = st.sidebar.slider("Number of tokens to generate", value=default_length, min_value=8, step=8, max_value=256)
47
  if gen_kwargs["do_sample"]:
48
  gen_kwargs["temperature"] = st.sidebar.slider("Temperature", value = 0.2, min_value = 0.0, max_value=2.0, step=0.05)
49
  gen_kwargs["top_k"] = st.sidebar.slider("Top-k", min_value = 0, max_value=100, value = 0)
examples.json CHANGED
@@ -1,31 +1,38 @@
1
  [
2
  {
3
  "name": "Hello World!",
4
- "value": "def print_hello_world():\n \"\"\"Print 'Hello World!'.\"\"\""
 
5
  },
6
  {
7
  "name": "Filesize",
8
- "value": "def get_file_size(filepath):"
 
9
  },
10
  {
11
  "name": "Python to Numpy",
12
- "value": "# calculate mean in native Python:\ndef mean(a):\n return sum(a)/len(a)\n\n# calculate mean numpy:\nimport numpy as np\n\ndef mean(a):"
 
13
  },
14
  {
15
  "name": "unittest",
16
- "value": "def is_even(value):\n \"\"\"Returns True if value is an even number.\"\"\"\n return value % 2 == 0\n\n# setup unit tests for is_even\nimport unittest"
 
17
 
18
  },
19
  {
20
  "name": "Scikit-Learn",
21
- "value": "import numpy as np\nfrom sklearn.ensemble import RandomForestClassifier\n\n# create training data\nX = np.random.randn(100, 100)\ny = np.random.randint(0, 1, 100)\n\n# setup train test split"
 
22
  },
23
  {
24
  "name": "Pandas",
25
- "value": "# load dataframe from csv\ndf = pd.read_csv(filename)\n\n# columns: \"age_group\", \"income\"\n# calculate average income per age group"
 
26
  },
27
  {
28
  "name": "Transformers",
29
- "value": "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n\n# build a BERT classifier"
 
30
  }
31
  ]
 
1
  [
2
  {
3
  "name": "Hello World!",
4
+ "value": "def print_hello_world():\n \"\"\"Print 'Hello World!'.\"\"\"",
5
+ "length": 8
6
  },
7
  {
8
  "name": "Filesize",
9
+ "value": "def get_file_size(filepath):",
10
+ "length": 64
11
  },
12
  {
13
  "name": "Python to Numpy",
14
+ "value": "# native Python:\ndef mean(a):\n return sum(a)/len(a)\n\n# with numpy:\nimport numpy as np\n\ndef mean(a):",
15
+ "length": 16
16
  },
17
  {
18
  "name": "unittest",
19
+ "value": "def is_even(value):\n \"\"\"Returns True if value is an even number.\"\"\"\n return value % 2 == 0\n\n# setup unit tests for is_even\nimport unittest",
20
+ "length": 64
21
 
22
  },
23
  {
24
  "name": "Scikit-Learn",
25
+ "value": "import numpy as np\nfrom sklearn.ensemble import RandomForestClassifier\n\n# create training data\nX = np.random.randn(100, 100)\ny = np.random.randint(0, 1, 100)\n\n# setup train test split",
26
+ "length": 96
27
  },
28
  {
29
  "name": "Pandas",
30
+ "value": "# load dataframe from csv\ndf = pd.read_csv(filename)\n\n# columns: \"age_group\", \"income\"\n# calculate average income per age group",
31
+ "length": 16
32
  },
33
  {
34
  "name": "Transformers",
35
+ "value": "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n\n# build a BERT classifier",
36
+ "length": 48
37
  }
38
  ]