Source code for ynlu.sdk.evaluation.plot.line_plot

from typing import List, Tuple

import seaborn as sns
import matplotlib.pyplot as plt

from .utils import plt_set_font_style

class _DataLengthMismatchError(KeyError, ValueError):
    """Raised when a datum's ``x`` and ``y`` sequences differ in length.

    Semantically this is a ``ValueError``: both keys are present, but their
    values are inconsistent. It also subclasses ``KeyError`` because this
    condition used to be reported as a plain ``KeyError``. Existing
    ``except KeyError`` handlers therefore keep catching it.
    """


def _check_data_format(data: List[dict]):
    """Validate the line specifications passed to ``plot_lines``.

    Args:
        data: Sequence of dictionaries, one per line to plot. Each dictionary
            must contain an ``x`` entry and a ``y`` entry, and both entries
            must support ``len()`` and have equal lengths.

    Raises:
        KeyError: If a datum has no ``x`` entry or no ``y`` entry.
        _DataLengthMismatchError: If ``len(datum["x"]) != len(datum["y"])``.
            This is a subclass of both ``KeyError`` and ``ValueError``.
    """
    for i, datum in enumerate(data):
        if "x" not in datum:
            raise KeyError("No.{}: datum {} has no x.".format(i, datum))
        if "y" not in datum:
            raise KeyError("No.{}: datum {} has no y.".format(i, datum))
        if len(datum["x"]) != len(datum["y"]):
            # Bad values, not a missing key: raise the ValueError-compatible
            # subclass while staying catchable as KeyError for old callers.
            raise _DataLengthMismatchError(
                "No.{}: x, y in datum must have the same length.".format(i))

[docs]def plot_lines( data: List[dict], title: str = "figure", x_axis_name: str = "x", y_axis_name: str = "y", figure_size: Tuple[int, int]=(8, 6), output_path: str = None, color_palette: sns.color_palette = None, font_style_path: str = None, dpi: int = 300, block: bool = True, ) -> None: """Plot y versus x as lines and/or markers Args: data (list of dictionaries): Dictionaries containing arguments and key words for ``matplotlib.pyplot.plot``. Basic arguments are: ``x``, ``y``, ``label`` (the name of line). title (string, default = "figure"): The title of the figure. x_axis_name (string, default = "x"): The name to be shown on x axis. y_axis_name(string, default = "y"): The name to be shown on y axis. figure_size (a pair of integers, default = (8, 6)): The height and width of the output figure. output_path (string, default = None): The place where the output figure would be stored. If it is None, the figure will be shown on screen automatically. color_palette (seaborn color palette object): Please take a look at ```` for more details. font_style_path (path of font style): If None, ``simhei.ttf`` will be used as default font style. Chinese characters are supported in this font style. dpi (int, default = 300): The resolution in dots per inch. block (bool): if False, the figure will not be shown up even if output_path is None. This argument is left for unittest. 
Returns: None Example: >>> from ynlu.sdk.evaluation.plot import plot_lines >>> plot_lines( data=[ {"x": [1, 2, 3], "y": [4, 5, 6], "label": "line1"}, {"x": [6, 7, 8], "y": [9, 10, 11], "label": "line2"}, ], ) """ _check_data_format(data=data) plt.figure(figsize=figure_size, dpi=dpi) if color_palette is None: color_palette = sns.color_palette("Set2", 10) sns.set_palette(color_palette) default_plot_params = { "linewidth": 2, "alpha": 0.7, "linestyle": "-", "marker": "o", } lines = [] line_names = [] for i, input_datum in enumerate(data): plot_params = default_plot_params plot_params.update(input_datum) del plot_params["x"], plot_params["y"] line, = plt.plot(input_datum["x"], input_datum["y"], **plot_params) lines.append(line) line_names.append(input_datum.get("label", "line_" + str(i))) plt.title(title) plt.xlabel(x_axis_name) plt.ylabel(y_axis_name) plt.legend( handles=lines, labels=line_names, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ) plt.subplots_adjust(right=0.8) plt_set_font_style(font_style_path=font_style_path) if output_path is not None: plt.savefig(output_path) else: