Source code for ynlu.sdk.evaluation.plot.confusion_matrix_plot

import itertools
from typing import List, Tuple

import numpy as np
import matplotlib.pyplot as plt

from .utils import plt_set_font_style


def _check_input_data_format(
        confusion_matrix: np.ndarray,
        indices: List[str],
    ) -> None:
    if confusion_matrix.shape[0] != confusion_matrix.shape[1]:
        raise ValueError(
            "confusion matrix should be a squared matrix.",
        )
    if confusion_matrix.shape[0] != len(indices):
        raise ValueError(
            """
            The shape of confusion matrix should be the same as
            the length of indices.
            """,
        )


[docs]def plot_confusion_matrix( confusion_matrix: np.ndarray, indices: List[str], normalize: bool = False, output_path: str = None, title: str = "Confusion Matrix", figure_size: Tuple[int, int] = (8, 6), cmap: plt.cm = plt.cm.Blues, font_style_path: str = None, dpi: int = 300, block: bool = True, ) -> None: """Plot confusion matrix Args: confusion_matrix (a squared numpy matrix): The outcome of ``sklearn.metrics.confusion_matrix`` is expected. indices (list of string): List of strings to index the matrix. The expected indices should be the same as the input argument ``labels`` when you generate confusion matrix. normalize (bool): If True, .. math:: \\text{confusion matrix}_{ij} = \\frac{\\text{confusion matrix}_{ij}}{\sum_{j=1}^{n}\\text{confusion matrix}_{ij}} output_path (string, default = None): The place where the output figure would be stored. If it is None, the figure will be shown on screen automatically. title (string, default = "Confusion Matrix"): The title of the figure. figure_size (a pair of integers, default = (8, 6)): The height and width of the output figure. cmap (color map): Matplotlib built-in colormaps. font_style_path (path of font style): If None, ``simhei.ttf`` will be used as default font style. Chinese characters are supported in this font style. dpi (int, default = 300): The resolution in dots per inch. block (bool): If False, the figure will not be shown up even if output_path is None. This argument is left for unittest. Returns: None Example: >>> import numpy as np >>> from ynlu.sdk.evaluation.plot import plot_confusion_matrix >>> plot_confusion_matrix( confusion_matrix=np.random.rand(4,4), indices=["a", "b", "c", "d"], ) """ # noqa _check_input_data_format( confusion_matrix=confusion_matrix, indices=indices, ) plt.figure(figsize=figure_size, dpi=dpi) if normalize: confusion_matrix = ( confusion_matrix.astype(float) / confusion_matrix.sum(axis=1)[:, np.newaxis] ) plt.imshow( confusion_matrix, interpolation='nearest', cmap=cmap, ) plt.title(title) plt.colorbar() tick_marks = np.arange(len(indices)) plt.xticks(tick_marks, indices, rotation=45) plt.yticks(tick_marks, indices) fmt = '.2f' if normalize else 'd' thresh = confusion_matrix.max() / 2. for i, j in itertools.product( range(confusion_matrix.shape[0]), range(confusion_matrix.shape[1]), ): plt.text(j, i, format(confusion_matrix[i, j], fmt), horizontalalignment="center", color="white" if confusion_matrix[i, j] > thresh else "black") plt.tight_layout() plt.ylabel('True Label') plt.xlabel('Predicted Label') plt_set_font_style(font_style_path=font_style_path) if output_path is not None: plt.savefig(output_path) print("Saving confusion matrix figure to {}".format(output_path)) else: plt.show(block=block)