Source code for pytagi.metric
import numpy as np
from pytagi.nn.data_struct import HRCSoftmax
from pytagi.tagi_utils import Utils
[docs]
class HRCSoftmaxMetric:
    """Error metric for classifiers whose output layer is a Hierarchical Softmax.

    The class builds the hierarchical-softmax structure once, at construction
    time, and reuses it to decode the network's output moments into class
    labels and to score those labels against the ground truth.
    """

    def __init__(self, num_classes: int):
        """Creates the metric and its hierarchical-softmax structure.

        :param num_classes: The total number of classes in the classification problem.
        :type num_classes: int
        """
        self.num_classes = num_classes
        self.utils = Utils()
        # The structure is obtained from the utility object and cached so
        # that every later call decodes against the same object.
        self.hrc_softmax: HRCSoftmax = self.utils.get_hierarchical_softmax(
            num_classes=num_classes
        )

    def error_rate(
        self, m_pred: np.ndarray, v_pred: np.ndarray, label: np.ndarray
    ) -> float:
        """Returns the fraction of observations that were misclassified.

        The predicted labels are decoded from the output moments and then
        compared against ``label`` with :func:`classification_error`.

        :param m_pred: The mean of the predictions from the model.
        :type m_pred: np.ndarray
        :param v_pred: The variance of the predictions from the model.
        :type v_pred: np.ndarray
        :param label: The ground truth labels.
        :type label: np.ndarray
        :return: The classification error rate, a value between 0 and 1.
        :rtype: float
        """
        predicted = self.get_predicted_labels(m_pred, v_pred)
        return classification_error(predicted, label)

    def get_predicted_labels(
        self, m_pred: np.ndarray, v_pred: np.ndarray
    ) -> np.ndarray:
        """Decodes the model's output moments into predicted class labels.

        :param m_pred: The mean of the predictions from the model.
        :type m_pred: np.ndarray
        :param v_pred: The variance of the predictions from the model.
        :type v_pred: np.ndarray
        :return: An array of predicted class labels.
        :rtype: np.ndarray
        """
        # NOTE(review): presumably ``hrc_softmax.len`` is the number of
        # network outputs per observation and ``m_pred`` is flattened over
        # the batch - confirm against ``Utils.get_labels``.
        n_obs = m_pred.shape[0] // self.hrc_softmax.len
        labels, _ = self.utils.get_labels(
            m_pred, v_pred, self.hrc_softmax, self.num_classes, n_obs
        )
        return labels
[docs]
def mse(prediction: np.ndarray, observation: np.ndarray) -> float:
    """Returns the Mean Squared Error (MSE) between two sets of values.

    The squared element-wise differences are averaged with ``np.nanmean``,
    so entries whose difference is NaN do not contribute to the result.

    :param prediction: The predicted values.
    :type prediction: np.ndarray
    :param observation: The actual (observed) values.
    :type observation: np.ndarray
    :return: The mean squared error.
    :rtype: float
    """
    residual = prediction - observation
    return np.nanmean(np.square(residual))
[docs]
def log_likelihood(
    prediction: np.ndarray, observation: np.ndarray, std: np.ndarray
) -> float:
    """Returns the average Gaussian log-likelihood of the observations.

    Each observation is scored under a normal distribution centred on the
    matching prediction with the matching standard deviation. The per-element
    values are averaged with ``np.nanmean``, so NaN entries are skipped.

    :param prediction: The predicted mean of the distribution.
    :type prediction: np.ndarray
    :param observation: The observed data points.
    :type observation: np.ndarray
    :param std: The standard deviation of the distribution.
    :type std: np.ndarray
    :return: The average log-likelihood value.
    :rtype: float
    """
    # Log of the Gaussian normalisation constant, 2 * pi * sigma^2.
    log_norm = np.log(2 * np.pi * (std**2))
    # Residual expressed in units of standard deviation.
    standardized = (observation - prediction) / std
    per_point = -0.5 * log_norm - 0.5 * (standardized**2)
    return np.nanmean(per_point)
[docs]
def rmse(prediction: np.ndarray, observation: np.ndarray) -> float:
    """Returns the Root Mean Squared Error (RMSE).

    This is the square root of the value produced by :func:`mse` for the
    same inputs.

    :param prediction: The predicted values.
    :type prediction: np.ndarray
    :param observation: The actual (observed) values.
    :type observation: np.ndarray
    :return: The root mean squared error.
    :rtype: float
    """
    return mse(prediction, observation) ** 0.5
[docs]
def classification_error(prediction: np.ndarray, label: np.ndarray) -> float:
    """Computes the classification error rate.

    This function calculates the fraction of predictions that do not match the
    true labels. Both inputs are flattened before comparison, so 1-D arrays,
    row vectors and column vectors are all treated as a plain sequence of
    labels.

    If the two inputs hold a different number of entries, only the overlapping
    leading entries are compared, while the denominator remains the number of
    predictions.

    :param prediction: An array of predicted labels.
    :type prediction: np.ndarray
    :param label: An array of true labels.
    :type label: np.ndarray
    :return: The classification error rate (proportion of incorrect
        predictions), or NaN when ``prediction`` is empty.
    :rtype: float
    """
    # Flatten explicitly: comparing a (N,) array with a (N, 1) array would
    # otherwise broadcast to an (N, N) matrix instead of pairing entries.
    pred = np.asarray(prediction).reshape(-1)
    truth = np.asarray(label).reshape(-1)

    # An error rate over zero predictions is undefined; report NaN rather
    # than dividing by zero.
    if pred.size == 0:
        return float("nan")

    # Restrict the comparison to the entries present in both inputs.
    n_compared = min(pred.size, truth.size)
    n_wrong = np.count_nonzero(pred[:n_compared] != truth[:n_compared])

    # Divide by the number of predictions, not by the length of the first
    # axis, so that a (1, N) prediction is normalised by N.
    return n_wrong / pred.size