Source code for smote_variants.noise_removal._condensednn

"""
This module implements the condensed nearest neighbors technique.
"""

import numpy as np

from sklearn.neighbors import KNeighborsClassifier

from ._noisefilter import NoiseFilter

from .._logger import logger

__all__ = ["CondensedNearestNeighbors"]

# Private module-level alias of the package logger; remove_noise() reports through it.
_logger = logger


class CondensedNearestNeighbors(NoiseFilter):
    """
    Condensed nearest neighbors noise filter.

    Keeps every minority sample and a condensed subset of the majority
    samples. The subset starts from the first majority sample; majority
    samples misclassified by a 1-nearest-neighbor classifier fitted on the
    current subset are added in rounds, until the classifier gets every
    remaining majority sample right or no majority samples remain.

    References:
        * BibTex::

            @ARTICLE{condensed_nn,
                author={Hart, P.},
                journal={IEEE Transactions on Information Theory},
                title={The condensed nearest neighbor rule (Corresp.)},
                year={1968},
                volume={14},
                number={3},
                pages={515-516},
                keywords={Pattern classification},
                doi={10.1109/TIT.1968.1054155},
                ISSN={0018-9448},
                month={May}}
    """

    def __init__(self, n_jobs=1, **_kwargs):
        """
        Set up the noise filter.

        Args:
            n_jobs (int): number of jobs handed to the nearest neighbor
                classifier used in remove_noise
            _kwargs: extra keyword arguments; accepted and ignored
        """
        super().__init__()
        # check_n_jobs is not defined in this class; it comes from the
        # NoiseFilter hierarchy
        self.check_n_jobs(n_jobs, "n_jobs")
        self.n_jobs = n_jobs

    def get_params(self, deep=False):
        """
        Return the parameters of the filter.

        Args:
            deep (bool): forwarded to ``NoiseFilter.get_params``

        Returns:
            dict: ``n_jobs`` followed by the parameters reported by
            ``NoiseFilter.get_params``; on a key clash the latter wins
        """
        params = {"n_jobs": self.n_jobs}
        params.update(NoiseFilter.get_params(self, deep))
        return params

    def remove_noise(self, X, y):
        """
        Remove noise from the dataset by condensing the majority class.

        Args:
            X (np.array): features
            y (np.array): target labels

        Returns:
            np.array, np.array: condensed features and labels; all minority
            samples come first, followed by the retained majority samples
        """
        _logger.info("%s: Running noise removal", self.__class__.__name__)
        self.class_label_statistics(y)

        minority = X[y == self.min_label]
        majority = X[y == self.maj_label]

        # Seed: every minority sample plus the first majority sample; the
        # remaining majority samples are the candidates to be tested.
        X_hat = np.vstack([minority, majority[0]])  # pylint: disable=invalid-name
        y_hat = np.hstack(
            [np.repeat(self.min_label, len(minority)), [self.maj_label]]
        )
        candidates = majority[1:]

        # Absorb, round by round, every candidate that 1-NN on the current
        # condensed set misclassifies; stop when none is misclassified or
        # when no candidates are left.
        while len(candidates) > 0:
            classifier = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
            classifier.fit(X_hat, y_hat)
            wrong = classifier.predict(candidates) != self.maj_label
            if not wrong.any():
                break

            n_added = np.count_nonzero(wrong)
            X_hat = np.vstack(  # pylint: disable=invalid-name
                [X_hat, candidates[wrong]]
            )
            y_hat = np.hstack([y_hat, np.repeat(self.maj_label, n_added)])
            # keep only the candidates that were classified correctly
            candidates = candidates[~wrong]

        return X_hat, y_hat