File size: 6,313 Bytes
174cd37
 
 
 
 
 
 
 
d812385
174cd37
 
67fa189
d812385
 
67fa189
1dfccc3
67fa189
 
174cd37
67fa189
7cb14dd
174cd37
67fa189
 
174cd37
67fa189
 
 
 
 
 
174cd37
67fa189
 
 
 
 
174cd37
 
 
7cb14dd
 
ce217e0
 
174cd37
 
67fa189
 
 
7cb14dd
 
 
67fa189
174cd37
cf6aebf
67fa189
cf6aebf
 
174cd37
 
 
ce217e0
174cd37
 
 
 
 
 
 
 
 
 
 
 
1dfccc3
174cd37
1dfccc3
 
174cd37
1dfccc3
 
 
 
174cd37
 
 
1dfccc3
 
 
 
 
628fe8f
d812385
 
 
 
 
 
 
 
 
 
174cd37
 
d812385
 
174cd37
d812385
 
 
174cd37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67fa189
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import json
import os
import pickle as pkl
import re
import shutil
import string
from collections import Counter
from pathlib import Path

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer


from pathlib import Path

# Base URL of the application server (a local endpoint on port 8000)
SERVER_URL = "http://localhost:8000/"

# Maximum length for user queries; compared against `len(user_query)` in
# `is_user_query_valid`
MAX_USER_QUERY_LEN = 128

# Base directories, all derived from the folder that contains this file
CURRENT_DIR = Path(__file__).parent
DEPLOYMENT_DIR = CURRENT_DIR / "deployment"
DATA_PATH = CURRENT_DIR / "files"

# Deployment sub-directories, all located under DEPLOYMENT_DIR
CLIENT_DIR = DEPLOYMENT_DIR / "client_dir"
SERVER_DIR = DEPLOYMENT_DIR / "server_dir"
KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"

# Directories that `clean_directory` removes and recreates
ALL_DIRS = [KEYS_DIR, CLIENT_DIR, SERVER_DIR]

# Model and data file locations (only the paths are built here; nothing is read)
LOGREG_MODEL_PATH = CURRENT_DIR / "models" / "cml_logreg.model"
ORIGINAL_FILE_PATH = DATA_PATH / "original_document.txt"
ANONYMIZED_FILE_PATH = DATA_PATH / "anonymized_document.txt"
MAPPING_UUID_PATH = DATA_PATH / "original_document_uuid_mapping.json"
MAPPING_ANONYMIZED_SENTENCES_PATH = DATA_PATH / "mapping_clear_to_anonymized.pkl"
MAPPING_ENCRYPTED_SENTENCES_PATH = DATA_PATH / "mapping_clear_to_encrypted.pkl"
MAPPING_DOC_EMBEDDING_PATH = DATA_PATH / "mapping_doc_embedding_path.pkl"

# Location of the prompt text file, stored next to the other data files
PROMPT_PATH = DATA_PATH / "chatgpt_prompt.txt"


# Example queries keyed by a display label; `is_user_query_valid` checks
# incoming queries against these values
DEFAULT_QUERIES = {
    "Example Query 1": "What is the amount of the contract between David and Kate?",
    "Example Query 2": "What's the duration of the contract?",
    "Example Query 3": "Does Kate have an international bank account?",
}

# Tokenizer and model are both created at import time from the
# "obi/deid_roberta_i2b2" checkpoint, so importing this module triggers the load.
# NOTE(review): presumably fetched from the Hugging Face Hub when not cached
# locally - confirm before importing this module in offline environments.
TOKENIZER = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
EMBEDDINGS_MODEL = AutoModel.from_pretrained("obi/deid_roberta_i2b2")

# Every ASCII punctuation character except "%" and "$", followed by the degree
# sign. Despite the name, the final value is a single string, not a list.
PUNCTUATION_LIST = "".join(char for char in string.punctuation if char not in "%$") + '°'


def clean_directory() -> None:
    """Reset every directory listed in `ALL_DIRS` to an empty state.

    Each directory that already exists is removed together with its contents,
    then every entry is (re)created, including any missing parent folders.
    """

    print("Cleaning...\n")
    for folder in ALL_DIRS:
        # `isdir` is already False for a missing path, so it covers the
        # existence check as well.
        if os.path.isdir(folder):
            shutil.rmtree(folder)
        folder.mkdir(parents=True, exist_ok=True)


def get_batch_text_representation(texts, model, tokenizer, batch_size=1):
    """Embed `texts` by mean-pooling the model's last hidden states, batch by batch.

    Args:
        texts: Sliceable sequence of inputs; each slice of `batch_size` items is
            handed to `tokenizer`.
        model: Callable invoked as `model(**inputs, output_hidden_states=False)`
            whose result exposes `last_hidden_state`.
        tokenizer: Callable invoked with `return_tensors="pt"`, `padding=True`
            and `truncation=True`; its result must contain "attention_mask".
        batch_size (int): Number of texts processed per model call.

    Returns:
        np.ndarray: The pooled vectors of all texts, in input order.
    """
    pooled_vectors = []
    for start in range(0, len(texts), batch_size):
        encoded = tokenizer(
            texts[start : start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # Gradient tracking is disabled for the forward pass.
        with torch.no_grad():
            model_output = model(**encoded, output_hidden_states=False)
        hidden = model_output.last_hidden_state

        # Expand the attention mask to the shape of the hidden states so that
        # positions with a zero mask contribute nothing to the sum.
        mask = encoded["attention_mask"].unsqueeze(-1).expand(hidden.size()).float()
        token_sum = (hidden * mask).sum(1)
        token_count = mask.sum(1)

        pooled_vectors.extend((token_sum / token_count).cpu().detach().numpy())
    return np.array(pooled_vectors)


def is_user_query_valid(user_query: str) -> bool:
    """
    Tell whether `user_query` fails the default-query and length checks.

    NOTE(review): despite the function name, the result is True when the query
    fails the checks and False when it passes them - confirm that callers
    interpret the result this way.
    Args:
        user_query (str): The input text to be checked (may be None).
    Returns:
        bool: False if `user_query` is one of the `DEFAULT_QUERIES` values or is
            at most `MAX_USER_QUERY_LEN` characters long; True if it is None or
            longer than `MAX_USER_QUERY_LEN`.
    """
    # The predefined example queries always pass
    if user_query in DEFAULT_QUERIES.values():
        return False

    # A missing query fails the check
    if user_query is None:
        return True

    # Any other query fails only when it exceeds the length limit
    return len(user_query) > MAX_USER_QUERY_LEN


def compare_texts_ignoring_extra_spaces(original_text, modified_text):
    """Tell whether two texts are equal once whitespace differences are ignored.

    Args:
        original_text (str): The original text for comparison.
        modified_text (str): The modified text to compare against the original.

    Returns:
        (bool): True if both texts consist of the same whitespace-separated
            tokens in the same order; False otherwise.
    """
    # `str.split()` without arguments drops leading/trailing whitespace and
    # treats every whitespace run as one separator, so equal token lists mean
    # the texts differ at most in their whitespace.
    return original_text.split() == modified_text.split()


def is_strict_deletion_only(original_text, modified_text):
    """Check that `modified_text` contains no token beyond those of `original_text`.

    Both texts are lower-cased and split on whitespace after punctuation has
    been separated from adjacent word characters. Token order is not compared;
    only how often each token occurs.

    Args:
        original_text (str): The reference text.
        modified_text (str): The text expected to be a reduced version of it.

    Returns:
        (bool): True if every token of `modified_text` occurs in `original_text`
            at least as many times; False otherwise.
    """
    # Zero-width boundary between a word character and an adjacent punctuation
    # character (in either order) with no whitespace between them.
    boundary = r"(?<=[\w])(?=[^\w\s])|(?<=[^\w\s])(?=[\w])"

    def count_tokens(text):
        # Detach punctuation from words, then count case-insensitive tokens.
        spaced = re.sub(boundary, " ", text)
        return Counter(spaced.lower().split())

    original_counts = count_tokens(original_text)
    modified_counts = count_tokens(modified_text)

    # Counter subtraction keeps only positive remainders, i.e. tokens that the
    # modified text uses more often than the original does.
    return not (modified_counts - original_counts)


def read_txt(file_path):
    """Return the full content of a UTF-8 encoded text file.

    Args:
        file_path: Location of the file to read.

    Returns:
        str: Everything the file contains.
    """
    with open(file_path, encoding="utf-8") as handle:
        content = handle.read()
    return content


def write_txt(file_path, data):
    """Store `data` in a UTF-8 encoded text file, replacing previous content.

    Args:
        file_path: Location of the file to write.
        data (str): Text to store.
    """
    with open(file_path, mode="w", encoding="utf-8") as handle:
        handle.write(data)


def write_pickle(file_path, data):
    """Serialize `data` with pickle into the file at `file_path`.

    Args:
        file_path: Location of the pickle file to (over)write.
        data: Any picklable object.
    """
    with open(file_path, mode="wb") as handle:
        pkl.Pickler(handle).dump(data)


def read_pickle(file_name):
    """Deserialize and return the object stored in a pickle file.

    Unpickling can execute arbitrary code, so this must only be used on files
    from a trusted source.

    Args:
        file_name: Location of the pickle file to read.

    Returns:
        The object that was pickled into the file.
    """
    with open(file_name, mode="rb") as handle:
        return pkl.Unpickler(handle).load()


def read_json(file_name):
    """Load data from a UTF-8 encoded json file.

    Args:
        file_name: Location of the json file to read.

    Returns:
        The Python object decoded from the file content.
    """
    # Decode explicitly as UTF-8: this matches the encoding `write_json` uses
    # and keeps the result independent of the platform's default encoding.
    with open(file_name, "r", encoding="utf-8") as file:
        return json.load(file)


def write_json(file_name, data):
    """Dump `data` as json into a UTF-8 encoded file.

    The output is indented by four spaces and object keys are sorted.

    Args:
        file_name: Location of the json file to (over)write.
        data: Any json-serializable object.
    """
    with open(file_name, mode="w", encoding="utf-8") as handle:
        json.dump(data, handle, sort_keys=True, indent=4)


def write_bytes(path, data):
    """Write `data` to `path` in binary mode, replacing previous content.

    Args:
        path: Object exposing `open(mode)`, e.g. a `pathlib.Path`.
        data: Bytes-like payload to store.
    """
    with path.open("wb") as sink:
        sink.write(data)


def read_bytes(path):
    """Return the raw content of the file at `path`.

    Args:
        path: Object exposing `open(mode)`, e.g. a `pathlib.Path`.

    Returns:
        bytes: Everything the file contains.
    """
    with path.open("rb") as source:
        payload = source.read()
    return payload