Source code for ynlu.sdk.evaluation.entity_overlapping_score

from typing import List

from .utils import preprocess_entity_prediction
from ynlu.sdk.evaluation import NOT_ENTITY


def single__entity_overlapping_score(
    utterance: str,
    entity_prediction: List[dict],
    y_true: List[str],
    wrong_penalty_rate: float = 2.0,
) -> float:
    """Overlapping Score of a Single Utterance

    Compare the true and predicted entities at the character level,
    then compute a score representing how well they overlap.

    Args:
        utterance (a string): The input when calling `model.predict()`.
        entity_prediction (a list of dictionaries): Entity part of the output
            when calling `model.predict()` with the above utterance as input.
        y_true (a list of strings): True entity list of that utterance.
        wrong_penalty_rate (float, default is 2.0): The penalty added for a
            character whose predicted and true entities are both real entities
            yet differ. A mismatch involving NOT_ENTITY always costs 1.0.

    Returns:
        overlapping score (float): With a wrong penalty rate of 2.0, the
            overlapping score ranges from -1 to 1. A score of -1 means every
            entity in the utterance is mismatched; a score of 1 means the
            entities are perfectly matched.

    Examples:
        >>> from ynlu.sdk.evaluation import single__entity_overlapping_score
        >>> overlapping_score = single__entity_overlapping_score(
                utterance="I like apple.",
                entity_prediction=[
                    {"entity": "DONT_CARE", "value": "I like ", "score": 0.9},
                    {"entity": "fruit", "value": "apple", "score": 0.8},
                    {"entity": "drink", "value": ".", "score": 0.3},
                ],
                y_true=[
                    "DONT_CARE", "DONT_CARE", "DONT_CARE", "DONT_CARE",
                    "DONT_CARE", "DONT_CARE", "DONT_CARE", "fruit", "fruit",
                    "fruit", "fruit", "fruit", "DONT_CARE",
                ],
            )
        >>> print(overlapping_score)
        12 / 13
    """
    y_pred = preprocess_entity_prediction(
        utterance=utterance,
        entity_prediction=entity_prediction,
        not_entity=NOT_ENTITY,
    )
    if len(y_pred) != len(y_true):
        raise ValueError(
            "Entity prediction and label must have same length!!!",
        )
    penalty = 0.0
    for token_p, token_t in zip(y_pred, y_true):
        if token_p == token_t:
            # exact match: no penalty
            penalty += 0.0
        elif (token_p in [NOT_ENTITY, NOT_ENTITY + "*"]) or (
                token_t in [NOT_ENTITY, NOT_ENTITY + "*"]):
            # one side is not an entity: mild penalty
            penalty += 1.0
        else:
            # both sides claim different real entities: heavy penalty
            penalty += wrong_penalty_rate
    overlapping_score = 1 - penalty / len(y_pred)
    return overlapping_score
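For intuition about the score's range: a matching character adds no penalty, a mismatch in which either side is NOT_ENTITY adds 1.0, and a mismatch between two different real entities adds wrong_penalty_rate. The standalone sketch below is illustrative only and not part of the module; the string "DONT_CARE" stands in for NOT_ENTITY and the helper name `overlap` is hypothetical. It mirrors the penalty loop on plain label lists and shows the two extremes of the default range.

# Illustrative sketch of the penalty rule above (not part of the library).
def overlap(y_pred, y_true, wrong_penalty_rate=2.0):
    penalty = 0.0
    for p, t in zip(y_pred, y_true):
        if p == t:
            continue                           # exact match: no penalty
        elif "DONT_CARE" in (p, t):
            penalty += 1.0                     # one side is not an entity
        else:
            penalty += wrong_penalty_rate      # two different real entities
    return 1 - penalty / len(y_pred)

print(overlap(["fruit", "fruit"], ["fruit", "fruit"]))  # 1.0 (perfect match)
print(overlap(["drink", "drink"], ["fruit", "fruit"]))  # -1.0 (every token wrong)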
def entity_overlapping_score(
    utterances: List[str],
    entity_predictions: List[List[dict]],
    y_trues: List[List[str]],
    wrong_penalty_rate: float = 2.0,
) -> float:
    """Averaged Overlapping Score over all Utterances

    Please take a look at the function **single__entity_overlapping_score**
    first. This function is simply a batch version of it: each utterance is
    passed to that function, and the resulting scores are collected and
    averaged.
    """
    if len(entity_predictions) != len(y_trues):
        raise ValueError(
            "Entity predictions and labels must have same amount!!!",
        )
    overlapping_scores = []
    for utt, pred, true in zip(utterances, entity_predictions, y_trues):
        overlapping_scores.append(
            single__entity_overlapping_score(
                utterance=utt,
                entity_prediction=pred,
                y_true=true,
                wrong_penalty_rate=wrong_penalty_rate,
            ),
        )
    return sum(overlapping_scores) / len(entity_predictions)
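A minimal usage sketch of the batch wrapper, assuming the same import path that the docstring above uses for the single-utterance function; the utterances, predictions, and labels below are illustrative values only, not library fixtures.

from ynlu.sdk.evaluation import entity_overlapping_score  # assumed import path

# Two toy utterances with character-level true labels (illustrative only).
score = entity_overlapping_score(
    utterances=["I like apple.", "Tea please"],
    entity_predictions=[
        [
            {"entity": "DONT_CARE", "value": "I like ", "score": 0.9},
            {"entity": "fruit", "value": "apple", "score": 0.8},
            {"entity": "drink", "value": ".", "score": 0.3},
        ],
        [
            {"entity": "drink", "value": "Tea", "score": 0.7},
            {"entity": "DONT_CARE", "value": " please", "score": 0.9},
        ],
    ],
    y_trues=[
        ["DONT_CARE"] * 7 + ["fruit"] * 5 + ["DONT_CARE"],
        ["drink"] * 3 + ["DONT_CARE"] * 7,
    ],
)
print(score)  # average of the two per-utterance overlapping scores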