Skip to content

thoth.handler

base_handler

BaseHandler

Bases: ABC

Abstract base class that handles the article-specific elements of the app.

Source code in thoth/handler/base_handler.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
class BaseHandler(ABC):
    """Abstract base class that handles the article-specific elements of the app.

    Concrete subclasses define ``ARTICLE_TITLE``, ``DATASETS``, ``SUMMARY`` and
    ``NAME`` and implement :meth:`render_playground`. Every non-abstract subclass
    is added to ``HANDLER_REGISTRY`` under its ``ARTICLE_TITLE`` as soon as the
    subclass is defined.
    """

    ARTICLE_TITLE: ClassVar[str]
    """The formatted name of the article"""

    DATASETS: ClassVar[List[str]]
    """List of dataset names available in the article"""

    SUMMARY: ClassVar[pd.DataFrame]
    """A set of values defining the properties of the ML method"""

    NAME: ClassVar[str]
    """The programmatic name of the handler"""

    def __init__(self) -> None:
        super().__init__()
        # The dataset and its train/test split are filled in by `render_eda`.
        # They start as None (rather than being annotated only) so code that
        # runs before `render_eda` can test `is None` instead of hitting an
        # AttributeError.
        self.dataset: Optional[Dict[str, Any]] = None
        self.data: Optional[pd.DataFrame] = None
        self.train_x: Optional[np.ndarray] = None
        self.test_x: Optional[np.ndarray] = None
        self.train_y: Optional[np.ndarray] = None
        self.test_y: Optional[np.ndarray] = None
        # Article markdown lives in thoth/static/text/<NAME>/.
        self.text_path = Path(__file__).parent.parent.joinpath(
            "static", "text", self.NAME
        )

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Register every concrete (non-abstract) subclass by its article title."""
        # Forward keyword arguments so cooperative class hierarchies keep working;
        # doing this first also means a class rejected by a parent is never registered.
        super().__init_subclass__(**kwargs)
        if not inspect.isabstract(cls):
            HANDLER_REGISTRY[cls.ARTICLE_TITLE] = cls

    def render_page(self) -> None:
        """Main method for rendering the entire page"""

        st.title(self.ARTICLE_TITLE)

        self.render_summary()
        with st.expander("Introduction", expanded=True):
            st.write(self.get_section("intro"), unsafe_allow_html=True)

        self.render_eda()
        self.render_playground()

    @st.cache(show_spinner=False)
    def get_section(self, section: str) -> str:
        """Retrieves the contents of a markdown file and returns them as a string

        Each article stores its text in markdown files located at
        `thoth/static/text/<NAME>/<section>.md` (the directory is ``self.text_path``).

        Args:
            section (str): The name of the section to retrieve the markdown for

        Returns:
            The markdown for the required section

        Raises:
            FileNotFoundError: If no markdown file exists for ``section``.
        """
        # Decode explicitly as UTF-8 so article text with non-ASCII characters
        # does not depend on the platform's default locale encoding.
        return (self.text_path / f"{section}.md").read_text(encoding="utf-8")

    def render_summary(self) -> None:
        """Create and render a chart showing basic qualities of the handler's ML method"""
        chart = (
            alt.Chart(self.SUMMARY)
            .mark_bar()
            .encode(
                y="Attribute:N",
                x="Score:Q",
                color=alt.Color("Attribute", legend=None),
                tooltip=["Attribute", "Score"],
            )
            .properties(title=f"{self.ARTICLE_TITLE} as a Machine Learning Model")
        )
        st.altair_chart(chart, use_container_width=True)

    @abstractmethod
    def render_playground(self) -> None:
        """Generates and renders the interactive playground for the handler's ML method

        The playground consists of two sections. The first involves choosing the parameters
        of the model, while the second presents relevant plots and metrics.
        """
        raise NotImplementedError

    def render_eda(self, dataset_index: Optional[int] = None) -> None:
        """Generate and render the data selection and exploration section of the article

        Each handler defines some datasets to choose from, and this function renders these options,
        and displays some interactive graphs to explore the data.

        Side effects: sets ``self.dataset`` and ``self.data`` from the chosen dataset,
        and ``self.train_x``/``test_x``/``train_y``/``test_y`` from a stratified,
        shuffled 80/20 split seeded with ``SEED``.

        Args:
            dataset_index: If supplied, specifies the index of the default dataset.
        """
        # * Dataset Selection
        st.header("Data Selection and Exploration")
        st.write(self.get_section("eda"))
        dataset_name = st.selectbox(
            "Choose a Dataset", self.DATASETS, index=dataset_index or 0
        )

        with st.spinner("Loading dataset"):
            self.dataset, self.data = utils.load_process_data(dataset_name)

        self.train_x, self.test_x, self.train_y, self.test_y = train_test_split(
            self.data.drop("label", axis=1),
            self.data["label"],
            train_size=0.8,
            stratify=self.data["label"],
            shuffle=True,
            random_state=SEED,
        )

        # Optionally display dataset information
        with st.expander("Dataset details"):
            st.write(self.dataset["DESCR"])
        st.write(self.data)

        # * EDA
        st.subheader("Simple Exploratory Data Analysis (EDA)")

        # Class Balance
        class_chart = (
            alt.Chart(self.data)
            .mark_bar()
            .encode(
                y=alt.Y("label", axis=alt.Axis(title="Class")),
                x=alt.X("count()", axis=alt.Axis(title="Count")),
                color=alt.Color("label", legend=None),
                tooltip=["label", "count()"],
            )
            .properties(title="Class Distribution")
        )
        st.altair_chart(class_chart, use_container_width=True)

        feature = st.selectbox("Feature", self.data.drop("label", axis=1).columns)

        # Compute the feature range once with pandas reductions (vectorised and
        # NaN-skipping, unlike builtin min/max), then pad it by 10% on each side
        # so the density curves are not clipped at the plot edges.
        feature_min = self.data[feature].min()
        feature_max = self.data[feature].max()
        buffer = 0.1 * (feature_max - feature_min)
        density_chart = (
            alt.Chart(self.data)
            .transform_density(
                density=feature,
                groupby=["label"],  # type: ignore
                steps=1000,  # type: ignore
                extent=[feature_min - buffer, feature_max + buffer],  # type: ignore
            )
            .mark_area()
            .encode(
                alt.X("value:Q", axis=alt.Axis(title=f"{feature}")),
                alt.Y("density:Q", axis=alt.Axis(title="Density")),
                color=alt.Color(
                    "label", legend=alt.Legend(orient="bottom", title="Class")
                ),
                opacity=alt.OpacityValue(0.8),
                tooltip=["label", "density:Q"],
            )
            .properties(title=f"Distribution of {feature} for each class")
        )
        st.altair_chart(density_chart, use_container_width=True)

ARTICLE_TITLE: ClassVar[str] = None class-attribute

The formatted name of the article

DATASETS: ClassVar[List[str]] = None class-attribute

List of dataset names available in the article

NAME: ClassVar[str] = None class-attribute

The programmatic name of the handler

SUMMARY: ClassVar[pd.DataFrame] = None class-attribute

A set of values defining the properties of the ML method

get_section(section)

Retrieves the contents of a markdown file and returns them as a string

Each article stores its text in markdown files located at `thoth/static/text/<NAME>/<section>.md`, where `NAME` is the handler's programmatic name.

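Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `section` | `str` | The name of the section to retrieve the markdown for | *required* |

Returns:

| Type | Description |
|------|-------------|
| `str` | The markdown for the required section |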

Source code in thoth/handler/base_handler.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
@st.cache(show_spinner=False)
def get_section(self, section: str) -> str:
    """Retrieves the contents of a markdown file and returns them as a string

    Each article stores its text in markdown files located at
    `thoth/static/text/<NAME>/<section>.md` (the directory is ``self.text_path``).

    Args:
        section (str): The name of the section to retrieve the markdown for

    Returns:
        The markdown for the required section

    Raises:
        FileNotFoundError: If no markdown file exists for ``section``.
    """
    # Decode explicitly as UTF-8 so article text with non-ASCII characters
    # does not depend on the platform's default locale encoding.
    return (self.text_path / f"{section}.md").read_text(encoding="utf-8")

render_eda(dataset_index=None)

Generate and render the data selection and exploration section of the article

Each handler defines some datasets to choose from, and this function renders these options, and displays some interactive graphs to explore the data.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `dataset_index` | `Optional[int]` | If supplied, specifies the index of the default dataset. | `None` |
Source code in thoth/handler/base_handler.py
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def render_eda(self, dataset_index: Optional[int] = None) -> None:
    """Generate and render the data selection and exploration section of the article

    Each handler defines some datasets to choose from, and this function renders these options,
    and displays some interactive graphs to explore the data.

    Side effects: sets ``self.dataset`` and ``self.data`` from the chosen dataset,
    and ``self.train_x``/``test_x``/``train_y``/``test_y`` from a stratified,
    shuffled 80/20 split seeded with ``SEED``.

    Args:
        dataset_index: If supplied, specifies the index of the default dataset.
    """
    # * Dataset Selection
    st.header("Data Selection and Exploration")
    st.write(self.get_section("eda"))
    dataset_name = st.selectbox(
        "Choose a Dataset", self.DATASETS, index=dataset_index or 0
    )

    with st.spinner("Loading dataset"):
        self.dataset, self.data = utils.load_process_data(dataset_name)

    self.train_x, self.test_x, self.train_y, self.test_y = train_test_split(
        self.data.drop("label", axis=1),
        self.data["label"],
        train_size=0.8,
        stratify=self.data["label"],
        shuffle=True,
        random_state=SEED,
    )

    # Optionally display dataset information
    with st.expander("Dataset details"):
        st.write(self.dataset["DESCR"])
    st.write(self.data)

    # * EDA
    st.subheader("Simple Exploratory Data Analysis (EDA)")

    # Class Balance
    class_chart = (
        alt.Chart(self.data)
        .mark_bar()
        .encode(
            y=alt.Y("label", axis=alt.Axis(title="Class")),
            x=alt.X("count()", axis=alt.Axis(title="Count")),
            color=alt.Color("label", legend=None),
            tooltip=["label", "count()"],
        )
        .properties(title="Class Distribution")
    )
    st.altair_chart(class_chart, use_container_width=True)

    feature = st.selectbox("Feature", self.data.drop("label", axis=1).columns)

    # Compute the feature range once with pandas reductions (vectorised and
    # NaN-skipping, unlike builtin min/max), then pad it by 10% on each side
    # so the density curves are not clipped at the plot edges.
    feature_min = self.data[feature].min()
    feature_max = self.data[feature].max()
    buffer = 0.1 * (feature_max - feature_min)
    density_chart = (
        alt.Chart(self.data)
        .transform_density(
            density=feature,
            groupby=["label"],  # type: ignore
            steps=1000,  # type: ignore
            extent=[feature_min - buffer, feature_max + buffer],  # type: ignore
        )
        .mark_area()
        .encode(
            alt.X("value:Q", axis=alt.Axis(title=f"{feature}")),
            alt.Y("density:Q", axis=alt.Axis(title="Density")),
            color=alt.Color(
                "label", legend=alt.Legend(orient="bottom", title="Class")
            ),
            opacity=alt.OpacityValue(0.8),
            tooltip=["label", "density:Q"],
        )
        .properties(title=f"Distribution of {feature} for each class")
    )
    st.altair_chart(density_chart, use_container_width=True)

render_page()

Main method for rendering the entire page

Source code in thoth/handler/base_handler.py
50
51
52
53
54
55
56
57
58
59
60
def render_page(self) -> None:
    """Render the whole article page, top to bottom.

    Order: title, summary chart, an initially-expanded "Introduction"
    expander, the data exploration section, then the model playground.
    """
    st.title(self.ARTICLE_TITLE)
    self.render_summary()

    with st.expander("Introduction", expanded=True):
        intro_markdown = self.get_section("intro")
        st.write(intro_markdown, unsafe_allow_html=True)

    # The remaining sections take no arguments; render them in page order.
    for render_section in (self.render_eda, self.render_playground):
        render_section()

render_playground() abstractmethod

Generates and renders the interactive playground for the handler's ML method

The playground consists of two sections. The first involves choosing the parameters of the model, while the second presents relevant plots and metrics.

Source code in thoth/handler/base_handler.py
 93
 94
 95
 96
 97
 98
 99
100
@abstractmethod
def render_playground(self) -> None:
    """Build and display the interactive playground for this handler's ML method.

    Subclasses must override this. The playground has two parts: controls for
    choosing the model's parameters, followed by the resulting plots and metrics.
    """
    # No generic playground exists; concrete handlers supply their own.
    raise NotImplementedError

render_summary()

Create and render a chart showing basic qualities of the handler's ML method

Source code in thoth/handler/base_handler.py
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def render_summary(self) -> None:
    """Draw a bar chart of the ``Score`` of each ``Attribute`` in ``self.SUMMARY``."""
    chart_title = f"{self.ARTICLE_TITLE} as a Machine Learning Model"
    encoding = dict(
        y="Attribute:N",
        x="Score:Q",
        color=alt.Color("Attribute", legend=None),
        tooltip=["Attribute", "Score"],
    )
    summary_chart = alt.Chart(self.SUMMARY).mark_bar()
    summary_chart = summary_chart.encode(**encoding).properties(title=chart_title)
    st.altair_chart(summary_chart, use_container_width=True)

dt

DTHandler

Bases: BaseHandler

Page handler for the Decision Tree article (short name 'dt')

Source code in thoth/handler/dt.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
class DTHandler(BaseHandler):
    """Page handler for the Decision Tree article (short name 'dt')"""

    ARTICLE_TITLE = "Decision Trees"
    DATASETS = ["Breast Cancer", "Iris", "Wine"]
    SUMMARY = pd.DataFrame(
        {
            "Attribute": ["Power", "Interpretability", "Simplicity"],
            "Score": [2, 5, 4],
        },
    )
    NAME = "dt"

    def render_eda(self, dataset_index: Optional[int] = None) -> None:
        """Render the data exploration section, preselecting "Iris" by default.

        Args:
            dataset_index: Index into ``DATASETS`` of the dataset to preselect.
                When omitted, the index of "Iris" is used.
        """
        if dataset_index is None:
            dataset_index = self.DATASETS.index("Iris")
        return super().render_eda(dataset_index=dataset_index)

    def render_playground(self) -> None:
        """Render decision-tree parameter controls, metrics, and the fitted tree.

        Trains a ``DecisionTreeClassifier`` on the split produced by
        ``render_eda``. The page then shows train/test metrics, a graphviz
        rendering of the tree, and the model's parameters.

        Raises:
            ValueError: If no dataset has been loaded yet (``render_eda`` has
                not populated the dataset and train/test split).
        """
        st.header("Model Playground")
        st.write(self.get_section("playground"))
        st.subheader("Parameter Selection")

        # getattr with a default keeps this guard meaningful even if the
        # attributes were never assigned. A plain attribute access would raise
        # AttributeError instead of the intended ValueError. `dataset` is also
        # required below for feature/class names.
        if any(
            getattr(self, name, None) is None
            for name in ("train_x", "train_y", "test_x", "test_y", "dataset")
        ):
            raise ValueError(
                "A dataset must be chosen before the playground can be rendered!"
            )

        params: Dict[str, Any] = {
            "random_state": SEED,
            "criterion": st.selectbox(
                "Splitting criterion:", ["gini", "entropy"], index=1
            ),
            "max_depth": st.slider(
                "Maximum tree depth:", min_value=1, max_value=10, value=5
            ),
            "min_samples_split": st.slider(
                "Minimum number of samples required to split",
                min_value=2,
                max_value=len(self.train_x),
                step=1,
            ),
        }

        with st.expander("Advanced parameters"):
            params["splitter"] = st.selectbox(
                "How to select feature to split by:", ["best", "random"]
            )
            params["min_impurity_decrease"] = st.slider(
                f"Minimum decrease in {params['criterion']} required to perform a split:",
                min_value=0.0,
                max_value=0.5,
                step=0.001,
                format="%.3f",
            )
            if st.checkbox("Balance classes inversely proportional to their frequency"):
                params["class_weight"] = "balanced"
            params["max_features"] = st.slider(
                "Number of features to consider at each split (randomly selected at each branch):",
                min_value=1,
                max_value=len(self.dataset["feature_names"]),
                value=len(self.dataset["feature_names"]),
            )

        decision_tree = utils.train_model(
            DecisionTreeClassifier, params, self.train_x, self.train_y
        )

        train_metrics = utils.get_metrics(decision_tree, self.train_x, self.train_y)
        train_metrics = train_metrics.set_axis(["Train"], axis="index")
        test_metrics = utils.get_metrics(decision_tree, self.test_x, self.test_y)
        test_metrics = test_metrics.set_axis(["Test"], axis="index")

        st.subheader("Performance Metrics")
        # DataFrame.append was removed in pandas 2.0; pd.concat stacks the
        # "Train" and "Test" rows identically on all pandas versions.
        st.write(pd.concat([train_metrics, test_metrics]))

        st.subheader("View Tree")
        with st.spinner("Plotting tree"):
            tree_dot = export_graphviz(
                decision_tree,
                out_file=None,
                rounded=True,
                filled=True,
                class_names=self.dataset["target_names"],
                feature_names=self.dataset["feature_names"],
            )
            st.graphviz_chart(tree_dot, use_container_width=True)

        st.subheader("Tree Parameters")
        st.write(decision_tree.get_params())