File size: 3,847 Bytes
597bf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2351d6
597bf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import pandas as pd
import streamlit as st

from subpages.page import Context, Page
from utils import colorize_classes, get_bg_color, get_fg_color, htmlify_labeled_example


class LossySamplesPage(Page):
    """Streamlit page listing samples in descending order of total loss.

    Each listed sample is rendered as an HTML-highlighted example and, optionally,
    as a styled table of its per-token columns.
    """

    name = "Samples by Loss"
    icon = "sort-numeric-down-alt"
    # Only samples whose ``total_loss`` is strictly greater than this are listed.
    loss_threshold = 0.5

    def get_widget_defaults(self):
        """Return default widget values for this page, keyed by widget key."""
        return {
            "skip_correct": True,
            "samples_by_loss_show_df": True,
        }

    @staticmethod
    def _colorize_label_col(col):
        """Return one CSS string per cell for a row of the transposed token table.

        Only the ``labels`` and ``preds`` rows are styled. The background comes from
        the segment after the first ``-`` (white when the value has no ``-``). The
        text colour is derived from that background. Values whose prefix is not
        ``B`` and that are not ``O`` are drawn at half opacity. All other rows get
        empty styles.
        """
        if col.name not in ("labels", "preds"):
            return [""] * len(col)
        styles = []
        for v in col.values:
            bg = get_bg_color(v.split("-")[1]) if "-" in v else "#ffffff"
            fg = get_fg_color(bg)
            op = "1" if v.split("-")[0] == "B" or v == "O" else "0.5"
            styles.append(f"background-color: {bg}; color: {fg}; opacity: {op};")
        return styles

    def render(self, context: Context):
        """Render the page.

        Samples are chosen and ordered from ``context.df`` (its ``total_loss``
        column). Each sample's token rows are then looked up in
        ``context.df_tokens_merged`` using the same index label.
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("Show every example sorted by loss (descending) for close inspection.")

        st.subheader("💥 Samples ⬇loss")
        # No ``value=``: like the checkbox below, the default is expected to come from
        # ``get_widget_defaults`` through the widget key (presumably seeded into
        # session state by the Page framework — confirm). Passing both ``value=`` and
        # a session-state default makes Streamlit emit a warning.
        skip_correct = st.checkbox("Skip correct examples", key="skip_correct")
        show_df = st.checkbox("Show dataframes", key="samples_by_loss_show_df")

        # Page-wide CSS: hide table headers and compact table cells.
        st.write(
            """<style>
thead {
    display: none;
}
td {
    white-space: nowrap;
    padding: 0 5px !important;
}
</style>""",
            unsafe_allow_html=True,
        )

        by_loss = context.df.sort_values(by="total_loss", ascending=False)
        top_indices = by_loss[by_loss["total_loss"] > self.loss_threshold].index

        cnt = 0
        for idx in top_indices:
            sample = context.df_tokens_merged.loc[idx]

            # ``.loc`` yields a Series rather than a DataFrame when ``idx`` matches a
            # single row; such samples are skipped.
            if isinstance(sample, pd.Series):
                continue

            # Optionally skip samples where every prediction equals its label.
            if skip_correct and (sample.labels == sample.preds).all():
                continue

            if show_df:
                df = sample.reset_index().drop(["index", "hidden_states", "ids"], axis=1).round(3)
                # After transposing, the original columns become rows; draw bars on "losses".
                losses_slice = pd.IndexSlice["losses", :]
                styler = (
                    df.T.style.apply(self._colorize_label_col, axis=1)
                    .bar(subset=losses_slice, axis=1)
                    .format(precision=3)
                )
                st.write(styler.to_html(), unsafe_allow_html=True)
                st.write("")

            col1, _, col2 = st.columns([3.5 / 32, 0.5 / 32, 28 / 32])

            cnt += 1
            counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{cnt} | {idx}]</span>"
            loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
            col1.write(f"{counter}{loss}", unsafe_allow_html=True)
            col1.write("")

            col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)