srinidhidevaraj committed on
Commit
d7f7f62
•
1 Parent(s): 41c1861

Upload 5 files

Browse files
Files changed (5) hide show
  1. helpers.py +214 -0
  2. prompt_template.py +42 -0
  3. requirements.txt +14 -0
  4. run_tree_search.py +174 -0
  5. tree_search_icd.py +47 -0
helpers.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import simple_icd_10_cm as cm
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ # from openai import OpenAI
6
+ from prompt_template import *
7
+ from langchain_groq import ChatGroq
8
+ from groq import Groq
9
+ from dotenv import load_dotenv
10
+ import csv
11
+ import time
12
# Load variables from a local .env file into the process environment before
# any of the lookups below run.
load_dotenv()

# LangChain tracing settings, exported through the environment.
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_TRACING_V2"] = "true"

# API keys read from the environment; each is None when the variable is unset.
# Neither value is passed explicitly to anything in this module.
groq_api_key = os.getenv('GROQ_API_KEY')
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")

# Module-wide Groq client (used by get_response), built with no explicit arguments.
client = Groq()

# Chapter list exposed by simple_icd_10_cm.
CHAPTER_LIST = cm.chapter_list
22
+
23
def construct_translation_prompt(medical_note):
    """
    Wrap a Spanish clinical note in the Spanish-to-English translation instructions.

    Args:
        medical_note (str): The Spanish medical note to be translated.

    Returns:
        str: The fixed instruction text followed by the note on its own line.
    """
    instructions = (
        "You are an expert Spanish-to-English translator. You are provided with a clinical note written in Spanish.\n"
        "You must translate the note into English. You must ensure that you properly translate the medical and technical terms from Spanish to English without any mistakes.\n"
        "Spanish Medical Note:\n"
    )
    return f"{instructions}{medical_note}"
39
+
40
def build_translation_prompt(input_note, system_prompt=""):
    """
    Assemble the chat messages for a zero-shot Spanish-to-English translation request.

    Args:
        input_note (str): The Spanish medical note to translate.
        system_prompt (str): Optional content for the system message (empty by default).

    Returns:
        list of dict: Two messages, the system message followed by the user message,
        each a dict with "role" and "content" keys.
    """
    system_message = {"role": "system", "content": system_prompt}
    user_message = {"role": "user", "content": construct_translation_prompt(input_note)}
    return [system_message, user_message]
55
+
56
+
57
def remove_extra_spaces(text):
    """
    Collapse every run of whitespace into a single space and trim both ends.

    Args:
        text (str): The text to normalise.

    Returns:
        str: The text with words separated by exactly one space and no
        leading or trailing whitespace.
    """
    # str.split() with no argument splits on whitespace runs and discards
    # empty leading/trailing pieces, so re-joining yields the normalised text.
    words = text.split()
    return " ".join(words)
68
+
69
def remove_last_parenthesis(text):
    """
    Strip the final parenthesised group, together with its contents, from the text.

    A group is removed only when no other non-nested "(...)" group follows it
    on the same line; the lookahead does not cross newlines, so in multi-line
    text the last group of each line is removed.

    Args:
        text (str): The string to clean.

    Returns:
        str: The string with the trailing parenthesised group(s) removed.
    """
    trailing_group = re.compile(r'\([^()]*\)(?!.*\([^()]*\))')
    return trailing_group.sub('', text)
82
+
83
def format_code_descriptions(text, model_name):
    """
    Clean an ICD-10 code description for use in a prompt.

    The last parenthesised group is removed and whitespace is normalised.

    Args:
        text (str): The raw ICD-10 code description.
        model_name (str): Accepted for interface compatibility with callers;
            the formatting does not currently depend on it.

    Returns:
        str: The description without its trailing parenthesised group and
        with runs of whitespace collapsed to single spaces.
    """
    cleaned_text = remove_last_parenthesis(text)
    return remove_extra_spaces(cleaned_text)
98
+
99
def construct_prompt_template(case_note, code_descriptions, model_name):
    """
    Fill the model-specific prompt template with a case note and candidate descriptions.

    Args:
        case_note (str): The medical case note.
        code_descriptions (str): The ICD-10 code descriptions, already joined into one string.
        model_name (str): Key selecting the template in prompt_template_dict.

    Returns:
        str: The completed prompt text.

    Raises:
        KeyError: If prompt_template_dict has no entry for model_name.
    """
    return prompt_template_dict[model_name].format(
        code_descriptions=code_descriptions,
        note=case_note,
    )
113
+
114
def build_zero_shot_prompt(input_note, descriptions, model_name, system_prompt=""):
    """
    Build a zero-shot classification prompt with system and user roles for a language model.

    Args:
        input_note (str): The input note or query.
        descriptions (list of str): List of ICD-10 code descriptions.
        model_name (str): Model identifier, used to select the prompt template.
        system_prompt (str): Optional initial system prompt or instruction.

    Returns:
        list of dict: A structured list of dictionaries defining the role and content of each message.
    """
    # Every model uses the same bullet layout, so no per-model branching is needed here.
    code_descriptions = "\n".join("* " + description for description in descriptions)

    input_prompt = construct_prompt_template(input_note, code_descriptions, model_name)
    return [{"role": "system", "content": system_prompt}, {"role": "user", "content": input_prompt}]
135
+
136
def get_response(messages, model_name, temperature=0.0, max_tokens=500):
    """
    Send a chat-completion request through the module-level client and return the reply text.

    Args:
        messages (list of dict): Chat messages to send, each with "role" and "content".
        model_name (str): Identifier of the model to query.
        temperature (float): Sampling temperature forwarded to the API.
        max_tokens (int): Upper bound on response tokens forwarded to the API.

    Returns:
        str: The message content of the first choice in the response.
    """
    request = {
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    completion = client.chat.completions.create(**request)
    first_choice = completion.choices[0]
    return first_choice.message.content
156
+
157
def remove_noisy_prefix(text):
    """
    Strip list decoration from a description echoed back by the model.

    Every "* " bullet marker is dropped, then one leading enumeration token
    (word characters followed by a dot, e.g. "1." or "a.") is removed.

    Args:
        text (str): The description as it appears in the model output.

    Returns:
        str: The description without bullet markers, leading enumeration or
        surrounding whitespace.
    """
    without_bullets = text.replace("* ", "")
    without_bullets = without_bullets.strip()
    # Anchored at the start of the string, so at most one prefix is removed.
    enumeration_prefix = re.compile(r"^\s*\w+\.\s*")
    return enumeration_prefix.sub("", without_bullets).strip()
162
def parse_outputs(output, code_description_map, model_name):
    """
    Parse model outputs to confirm ICD-10 codes based on a given description map.

    Each non-empty output line is expected to look like
    "<description>: Yes,<evidence>" or "<description>: No.". Lines without a
    colon, lines whose verdict is not "yes", and lines whose description is
    not a key of code_description_map are skipped.

    A "yes" verdict is accepted even when the model leaves out the comma
    before the evidence (e.g. "Yes" or "Yes. <evidence>"), so a confirmed
    code is not lost to a formatting slip. The evidence text is returned with
    surrounding whitespace removed.

    Args:
        output (str): The model output containing confirmations.
        code_description_map (dict): Mapping of descriptions to ICD-10 codes.
        model_name (str): Accepted for interface compatibility with callers;
            parsing does not currently depend on it.

    Returns:
        list of dict: One dict per confirmed code with the keys "ICD Code",
        "Code Description" and "Evidence From Notes".
    """
    # "yes" as a whole word, then any separating punctuation, then the evidence.
    yes_verdict = re.compile(r"\s*yes\b[\s,.:;\-]*(.*)", re.IGNORECASE | re.DOTALL)

    confirmed_codes = []
    for line in output.split("\n"):
        if not line.strip():
            continue

        # Split on the first colon only: the evidence may itself contain colons.
        code_description, separator, confirmation = line.partition(":")
        if not separator:
            continue

        verdict = yes_verdict.match(confirmation)
        if verdict is None:
            continue

        code_description = remove_noisy_prefix(code_description)
        if code_description not in code_description_map:
            # The model echoed a description we did not offer; ignore it.
            continue

        confirmed_codes.append({
            "ICD Code": code_description_map[code_description],
            "Code Description": code_description,
            "Evidence From Notes": verdict.group(1).strip(),
        })
    return confirmed_codes
201
+
202
def get_name_and_description(code, model_name):
    """
    Look up an ICD-10 code and return its cleaned description together with its name.

    Args:
        code (str): The ICD-10 code to look up.
        model_name (str): Forwarded to format_code_descriptions.

    Returns:
        tuple: (formatted description, name), built from the lines at index 3
        and index 1 of the text returned by cm.get_full_data.
    """
    # NOTE(review): relies on the line layout of cm.get_full_data's output
    # (index 1 = name, index 3 = description) - confirm against simple_icd_10_cm.
    lines = cm.get_full_data(code).split("\n")
    name_line = lines[1]
    description_line = lines[3]
    return format_code_descriptions(description_line, model_name), name_line
214
+
prompt_template.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Worked example plus task instructions shared by every model's template.
# "{code_descriptions}" is filled in later via str.format.
_EXAMPLE_AND_TASK = (
    "[Example]:\n"
    "<code descriptions>\n"
    "* Gastro-esophageal reflux disease\n"
    "* Enteroptosis\n"
    "* Acute Nasopharyngitis [Common Cold]\n"
    "</code descriptions>\n"
    "\n"
    "<response>\n"
    "* Gastro-esophageal reflux disease: Yes,Patient was prescribed omeprazole.\n"
    "* Enteroptosis: No.\n"
    "* Acute Nasopharyngitis [Common Cold]: No.\n"
    "</response>\n"
    "\n"
    "[Task]:\n"
    "Follow the format in the example response exactly, including the entire description after your (Yes|No) judgement , followed by a newline.\n"
    "Consider each of the following ICD-10 code descriptions and evaluate if there are any related mentions in the Case note.\n"
    "{code_descriptions}"
)

# Prompt templates keyed by model name. Each template takes the "{note}" and
# "{code_descriptions}" fields; the llama3 variant has one extra blank line
# after the note.
prompt_template_dict = {
    "mixtral-8x7b-32768": "[Case note]:\n{note}\n" + _EXAMPLE_AND_TASK,
    "llama3-70b-8192": "[Case note]:\n{note}\n\n" + _EXAMPLE_AND_TASK,
}
42
+
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ python-dotenv
3
+ simple_icd_10_cm
4
+ tqdm
5
+ transformers
6
+ groq
7
+ langchain
8
+ langchain-groq
9
+ langchain-community
10
+ torch
11
+ tensorflow
12
+ flax
13
+ jax
14
+ jaxlib
run_tree_search.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pandas as pd
4
+ import json
5
+ from tree_search_icd import get_icd_codes
6
+ from tqdm import tqdm
7
+ import csv
8
+ import streamlit as st
9
+ import tempfile
10
+ from pathlib import Path
11
+ from io import StringIO
12
+
13
+ # def process_medical_notes(file_path,model_name):
14
+ # def process_medical_notes(input_dir, output_file, model_name):
15
+
16
+ # code_map = {}
17
+
18
+ # if not os.path.isdir(input_dir):
19
+ # raise ValueError("The specified input directory does not exist.")
20
+
21
+ # # Process each file in the input directory
22
+ # for files in tqdm(os.listdir(input_dir)):
23
+ # file_path = os.path.join(input_dir, files)
24
+ # print(file_path)
25
+ # with open(file_path, "r", encoding="utf-8") as file:
26
+ # medical_note = file.read()
27
+
28
+ # if not os.path.isfile(file_path):
29
+ # print(f"File does not exist: {file_path}")
30
+ # return None
31
+
32
+
33
+ # # if os.path.isfile(file_path):
34
+ # # st.write(f"File exists: {file_path}")
35
+
36
+ # # try:
37
+
38
+ # # with open(file_path, "r",encoding="utf-8") as txtfile:
39
+ # # st.write(file_path)
40
+ # # medical_note = txtfile.read()
41
+
42
+ # # st.write(f"Content of the file: {medical_note[:1000]}") # Print the first 1000 characters
43
+ # # except Exception as e:
44
+ # # print(f"Error reading file: {e}")
45
+ # # return None
46
+
47
+ # # print(f"File read successfully. Content length: {len(medical_note)}")
48
+
49
+ # #print(medical_note)
50
+ # icd_codes = get_icd_codes(medical_note, model_name)
51
+ # print(icd_codes)
52
+ # # return icd_codes
53
+ # # print(icd_codes)
54
+ # # code_map[files] = icd_codes
55
+
56
+ # with open(output_file, "w") as f:
57
+ # json.dump(code_map, f, indent=4)
58
+
59
+
60
+ # if __name__ == "__main__":
61
+ # parser = argparse.ArgumentParser(description="Process medical notes to extract ICD codes using a specified model.")
62
+ # parser.add_argument("--input_dir", help="Directory containing the medical text files")
63
+ # parser.add_argument("--output_file", help="File to save the extracted ICD codes in JSON format")
64
+ # parser.add_argument("--model_name", default="llama3-70b-8192", help="Model name to use for ICD code extraction")
65
+
66
+ # args = parser.parse_args()
67
+ # process_medical_notes(args.input_dir, args.output_file, args.model_name)
68
+
69
def process_medical_notes(filepath, model_name):
    """
    Read one medical note from disk and extract its ICD codes.

    The file is opened exactly once. (Iterating over the path string, as was
    done previously, reopened the file once per character of the path and
    left the note unbound for an empty path.)

    Args:
        filepath (str): Path to a UTF-8 encoded text file containing the note.
        model_name (str): Model identifier forwarded to get_icd_codes.

    Returns:
        The value returned by get_icd_codes for the note, or None when the
        file cannot be opened or decoded.
    """
    try:
        with open(filepath, "r", encoding="utf-8") as note_file:
            medical_note = note_file.read()
    except (OSError, UnicodeError, TypeError, ValueError):
        # Best-effort contract kept for callers: an unreadable note yields None.
        return None

    return get_icd_codes(medical_note, model_name)
85
+
86
+
87
+
88
def add_custom_css():
    """
    Inject the app's custom stylesheet into the Streamlit page.

    The CSS is emitted as a raw <style> element through st.markdown with
    unsafe_allow_html enabled.
    """
    stylesheet = """
        <style>
        /* Remove padding around the main block */
        .block-container {
        padding: 1rem;
        }
        /* Remove padding around the top */
        header, footer, .reportview-container .main .block-container {
        padding: 5;
        }
        /* Fullscreen layout adjustments */
        .css-1d391kg {
        padding: 5;
        }

        h1 {
        text-align: center;
        }
        .table-wrapper {
        text-align: center;
        }




        </style>
        """
    st.markdown(stylesheet, unsafe_allow_html=True)
119
def main():
    """
    Run the Streamlit page: upload note files, pick a model, and show the ICD codes found.

    For every uploaded .txt file the content is written to a temporary file,
    passed to process_medical_notes, and the resulting codes are rendered as
    an HTML table under a "Case Id" heading. The temporary file is removed
    once it has been processed.
    """
    import html  # local import: only needed to escape the uploaded file name

    st.set_page_config(layout="wide", page_icon='🔎', page_title='ICD Identifier')
    add_custom_css()
    st.title("ICD Code Extractor From Medical Notes")

    col1, col2 = st.columns([1, 5])
    with col2:
        file_uploads = st.file_uploader('Choose Medical Note File', type='txt', accept_multiple_files=True)
        submit = st.button("Submit")

    with col1:
        model_name = st.selectbox(
            "Select Model",
            ["llama3-70b-8192", "mixtral-8x7b-32768"],
            index=0,  # Default model selected
        )

    if not submit:
        return

    for file_input in file_uploads:
        file_name = Path(file_input.name).name

        # Close the temporary file before handing its path on: reopening a
        # NamedTemporaryFile that is still open fails on some platforms.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.txt') as temp_file:
            temp_file.write(file_input.getbuffer())
            temp_path = temp_file.name
        try:
            response = process_medical_notes(temp_path, model_name)
        finally:
            # delete=False means nobody else cleans this up.
            os.unlink(temp_path)

        with col2:
            # The file name is user-controlled, so escape it before embedding in HTML.
            st.markdown(f"<h5>Case Id: {html.escape(file_name)}</h5>", unsafe_allow_html=True)

            if response is None:
                # process_medical_notes returns None when the note could not be read.
                st.error(f"Could not read {file_name}.")
                continue

            res_data = pd.DataFrame(response, columns=['ICD Code', 'Code Description', 'Evidence From Notes'])
            # Cell text originates from the note and the model, so render it escaped.
            table_html = res_data.style.format(escape="html").hide(axis="index").to_html()
            st.markdown(table_html, unsafe_allow_html=True)


if __name__ == "__main__":
    main()
tree_search_icd.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from helpers import *
2
+
3
def get_icd_codes(medical_note, model_name, temperature=0.0, max_prompts=50):
    """
    Identifies relevant ICD-10 codes for a given medical note by querying a language model.

    This function implements the tree-search algorithm for ICD coding described in https://openreview.net/forum?id=mqnR8rGWkn.

    The search starts from the ICD-10 chapters. In each round the model is
    asked which of the current candidate descriptions are mentioned in the
    note; confirmed leaf codes are collected, confirmed non-leaf codes are
    queued, and the children of the next queued code become the candidates
    for the following round.

    Args:
        medical_note (str): The medical note for which ICD-10 codes are to be identified.
        model_name (str): The identifier for the language model used in the API.
        temperature (float): Sampling temperature forwarded to the model (default 0.0).
        max_prompts (int): Maximum number of model queries before the search
            stops, even if queued parent codes remain (default 50).

    Returns:
        list of dict: One dict per confirmed leaf code with the keys
        "ICD Code", "Code Description" and "Evidence From Notes".
    """
    from collections import deque  # local import: FIFO queue of codes still to expand

    assigned_codes = []
    candidate_codes = [x.name for x in CHAPTER_LIST]
    parent_codes = deque()

    for _ in range(max_prompts):
        # Map each formatted description back to its code so the model's
        # answers, which echo descriptions, can be resolved.
        code_descriptions = {}
        for candidate in candidate_codes:
            description, code = get_name_and_description(candidate, model_name)
            code_descriptions[description] = code

        prompt = build_zero_shot_prompt(medical_note, list(code_descriptions), model_name=model_name)
        lm_response = get_response(prompt, model_name, temperature=temperature, max_tokens=500)
        predicted_codes = parse_outputs(lm_response, code_descriptions, model_name=model_name)

        for code in predicted_codes:
            if cm.is_leaf(code["ICD Code"]):
                assigned_codes.append({
                    "ICD Code": code["ICD Code"],
                    "Code Description": code["Code Description"],
                    "Evidence From Notes": code["Evidence From Notes"],
                })
            else:
                parent_codes.append(code)

        if not parent_codes:
            break

        # Expand the oldest confirmed non-leaf code next.
        parent_code = parent_codes.popleft()
        candidate_codes = cm.get_children(parent_code["ICD Code"])

    return assigned_codes