import networkx as nx
import numpy as np
from cdlib import algorithms

# These metric helpers are heavily influenced by the HF squad_metrics.py script.


def normalize_text(s):
    """Lowercase *s*, strip punctuation and English articles, collapse whitespace.

    Standard SQuAD-style answer normalization applied before comparing strings.
    """
    import re
    import string

    def remove_articles(text):
        # \b prevents eating "a"/"an"/"the" when they occur inside longer words.
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def compute_exact_match(prediction, truth):
    """Return 1 if the normalized prediction equals the normalized truth, else 0."""
    return int(normalize_text(prediction) == normalize_text(truth))


def compute_f1(prediction, truth):
    """Token-level F1 between normalized *prediction* and *truth* (SQuAD metric).

    Token overlap is computed on sets, so duplicate tokens count once.
    Returns a number in [0, 1]. If either side normalizes to no tokens,
    the score is 1 when both are empty and 0 otherwise.
    """
    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # If either the prediction or the truth is no-answer,
    # F1 is 1 only when they agree (both empty).
    if not pred_tokens or not truth_tokens:
        return int(pred_tokens == truth_tokens)

    common_tokens = set(pred_tokens) & set(truth_tokens)
    if not common_tokens:
        return 0

    prec = len(common_tokens) / len(pred_tokens)
    rec = len(common_tokens) / len(truth_tokens)
    return 2 * (prec * rec) / (prec + rec)


def is_date_or_num(answer):
    """Heuristic: True if *answer* contains a number or a Vietnamese date word.

    "ngày"/"tháng"/"năm" are Vietnamese for day/month/year.
    """
    for word in answer.lower().split():
        word = word.strip()
        if word.isnumeric() or word in ("ngày", "tháng", "năm"):
            return True
    return False


def find_best_cluster(answers, best_answer, thr=0.79):
    """Pick a representative answer from *answers* via community clustering.

    Builds a similarity graph over the candidates (edges between
    near-duplicate answers), partitions it with Louvain, and returns:

      * ``best_answer`` when it belongs to one of the largest communities,
      * otherwise the median-length member of the first largest community,
      * ``best_answer`` as a fallback on any clustering failure
        (e.g. a disconnected graph).

    Parameters
    ----------
    answers : list[str]
        Candidate answers; may be empty.
    best_answer : str
        Highest-scoring single answer, used as tie-breaker and fallback.
    thr : float, optional
        Token-F1 threshold at or above which two answers count as duplicates.
    """
    if len(answers) == 0:
        return best_answer
    if len(answers) == 1:
        return answers[0]

    n = len(answers)
    dists = np.zeros((n, n))
    for i in range(n - 1):
        for j in range(i + 1, n):
            a1 = answers[i].lower().strip()
            a2 = answers[j].lower().strip()
            if is_date_or_num(a1) or is_date_or_num(a2):
                # Dates/numbers must match (near-)exactly: identical, or one is
                # a month-qualified ("tháng") substring of the other.
                if a1 == a2 or ("tháng" in a1 and a1 in a2) or ("tháng" in a2 and a2 in a1):
                    dists[i, j] = 1
                    dists[j, i] = 1
            # a1/a2 are already lowercased above, so no extra .lower() is needed.
            elif a1 == a2 or a1 in a2 or a2 in a1 or compute_f1(a1, a2) >= thr:
                dists[i, j] = 1
                dists[j, i] = 1

    try:
        # dists holds only 0/1 entries, so this selects exactly the duplicate pairs.
        dup_rows, dup_cols = np.where(dists >= 1)
        edges = [(i, j) for i, j in zip(dup_rows, dup_cols) if i != j]

        G = nx.Graph()
        for idx, answer in enumerate(answers):
            G.add_node(idx, content=answer)
        G.add_edges_from(edges)

        partition = algorithms.louvain(G)
        max_len_comm = max(len(comm) for comm in partition.communities)
        best_comms = [
            [answers[i] for i in comm]
            for comm in partition.communities
            if len(comm) == max_len_comm
        ]

        # Prefer the pre-selected best answer when it sits in a largest community.
        for comm in best_comms:
            if best_answer in comm:
                return best_answer

        # Otherwise return the median-length answer of the first largest community.
        mid = len(best_comms[0]) // 2
        return sorted(best_comms[0], key=len)[mid]
    except Exception as e:
        # Best-effort clustering: fall back to the single best answer on failure.
        print(e, "Disconnected graph")
        return best_answer