vulnerability_analysisT / appStore /vulnerability_analysis.py
TeresaK's picture
Upload 38 files
a5e9cde
raw
history blame
2.72 kB
# set path
import glob, os, sys;
sys.path.append('../utils')
#import needed libraries
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from utils.vulnerability_classifier import load_vulnerabilityClassifier, vulnerability_classification
import logging
logger = logging.getLogger(__name__)
from utils.config import get_classifier_params
from utils.preprocessing import paraLengthCheck
from io import BytesIO
import xlsxwriter
import plotly.express as px
# Module-level configuration for this page.
# classifier_identifier is the key handed to get_classifier_params() and is also
# used by app() to build the session-state key under which the model is stored.
classifier_identifier = 'vulnerability'
# Classifier settings; app() reads params['model_name'] and params['threshold'].
params = get_classifier_params(classifier_identifier)
@st.cache_data
def to_excel(df, sectorlist):
    """Serialize ``df`` to an in-memory .xlsx file with drop-down validation.

    The frame is written to sheet ``Sheet1`` without its index, so the header
    occupies spreadsheet row 1 and the data occupies rows 2 .. len(df) + 1.
    List-type data validation is then attached to fixed columns:

    * column S: one of ``No`` / ``Yes`` / ``Discard``
    * columns T, U, V, W, X: one of the entries of ``sectorlist`` or ``Blank``

    NOTE(review): the column letters are hard-coded; this presumably matches
    the column layout of the frame built by the caller -- confirm there.
    NOTE(review): XlsxWriter documents a 255-character limit on the joined
    list source string; a long ``sectorlist`` may exceed it -- confirm.

    Args:
        df: pandas DataFrame to export.
        sectorlist: iterable of labels offered in the T-X drop-downs.

    Returns:
        bytes: the content of the generated workbook.
    """
    # Header is on row 1, therefore the last data row is len(df) + 1
    # (the previous code stopped one row short).
    last_row = len(df) + 1
    sector_choices = list(sectorlist) + ['Blank']

    output = BytesIO()
    # The context manager finalizes the workbook on exit; ExcelWriter.save()
    # no longer exists in pandas >= 2.0.
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, index=False, sheet_name='Sheet1')
        worksheet = writer.sheets['Sheet1']

        # Without data rows there is nothing to validate, and a range such as
        # 'S2:S1' would be inverted.
        if last_row >= 2:
            worksheet.data_validation(
                'S2:S{}'.format(last_row),
                {'validate': 'list',
                 'source': ['No', 'Yes', 'Discard']})
            # One validation per column; the former 'W2:U..' range was a typo
            # for 'W2:W..' and overlapped the rules on columns U and V.
            for column in ('T', 'U', 'V', 'W', 'X'):
                worksheet.data_validation(
                    '{0}2:{0}{1}'.format(column, last_row),
                    {'validate': 'list',
                     'source': sector_choices})

    return output.getvalue()
def app():
    """Run the vulnerability classifier on the documents held in session state.

    Does nothing unless ``st.session_state`` contains ``'combined_files_df'``.
    Otherwise the classifier named by ``params['model_name']`` is loaded and
    stored in session state under ``'<classifier_identifier>_classifier'``,
    ``vulnerability_classification`` is called on the stored frame with
    ``params['threshold']``, and its result is written back to
    ``st.session_state['combined_files_df']``.
    """
    with st.container():
        # Guard clause: nothing to classify yet.
        if 'combined_files_df' not in st.session_state:
            return

        documents_df = st.session_state['combined_files_df']

        # Load the model and store it in session state before classifying,
        # preserving the original order of operations.
        model = load_vulnerabilityClassifier(classifier_name=params['model_name'])
        session_key = '{}_classifier'.format(classifier_identifier)
        st.session_state[session_key] = model

        classified_df = vulnerability_classification(
            haystack_doc=documents_df,
            threshold=params['threshold'])
        st.session_state['combined_files_df'] = classified_df