package net.sf.javaml.classification;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.core.exception.TrainingRequiredException;
import net.sf.javaml.distance.DistanceMeasure;
import net.sf.javaml.distance.EuclideanDistance;

/* loaded from: input_file:net/sf/javaml/classification/KNearestNeighbors.class */
/**
 * A k-nearest-neighbours classifier.
 *
 * <p>Training is lazy: {@link #buildClassifier(Dataset)} only keeps a reference to the training
 * data. All work happens in {@link #classDistribution(Instance)}. That method asks the training
 * data for the {@code k} instances nearest to the query under the configured
 * {@link DistanceMeasure}, and lets each of them cast one vote for its class.
 */
public class KNearestNeighbors extends AbstractClassifier {
    private static final long serialVersionUID = 1560149339188819924L;

    /** Training data; {@code null} until {@link #buildClassifier(Dataset)} has been called. */
    private Dataset training;

    /** Number of neighbours consulted for each query. */
    private final int k;

    /** Distance measure used to find the nearest training instances. */
    private final DistanceMeasure dm;

    /**
     * Creates a classifier that ranks neighbours with {@link EuclideanDistance}.
     *
     * @param k number of neighbours to consult per query
     */
    public KNearestNeighbors(int k) {
        this(k, new EuclideanDistance());
    }

    /**
     * Creates a classifier with an explicit distance measure.
     *
     * @param k  number of neighbours to consult per query
     * @param dm distance measure passed to {@code Dataset.kNearest} when searching neighbours
     */
    public KNearestNeighbors(int k, DistanceMeasure dm) {
        this.k = k;
        this.dm = dm;
    }

    /**
     * Stores the training data. No model is fitted; the dataset is kept by reference, so later
     * changes to it are seen by subsequent queries.
     *
     * @param dataset the training data
     */
    @Override
    public void buildClassifier(Dataset dataset) {
        this.training = dataset;
    }

    /**
     * Computes a per-class score from the votes of the {@code k} nearest training instances.
     *
     * <p>Every class in {@code training.classes()} starts with a score of 0. Each neighbour adds
     * 1 to its own class. The counts are then min-max scaled, so the weakest class maps to 0.0
     * and the strongest to 1.0. If every class received the same number of votes, the raw counts
     * are returned unscaled.
     *
     * @param instance the query instance
     * @return a mutable map from class value to score
     * @throws TrainingRequiredException if {@link #buildClassifier(Dataset)} has not been called
     */
    @Override
    public Map<Object, Double> classDistribution(Instance instance) {
        if (training == null) {
            throw new TrainingRequiredException();
        }
        Map<Object, Double> votes = countVotes(training.kNearest(k, instance, dm));
        normalizeMinMax(votes);
        return votes;
    }

    /** Tallies one vote per neighbour, starting every known training class at zero. */
    private Map<Object, Double> countVotes(Set<Instance> neighbors) {
        Map<Object, Double> votes = new HashMap<Object, Double>();
        for (Object label : training.classes()) {
            votes.put(label, 0.0);
        }
        for (Instance neighbor : neighbors) {
            Object label = neighbor.classValue();
            // Null-safe: a label missing from classes() gets its own entry instead of an NPE.
            Double current = votes.get(label);
            votes.put(label, current == null ? 1.0 : current + 1.0);
        }
        return votes;
    }

    /** Rescales the scores in place to [0, 1] using their actual min and max; no-op on ties. */
    private static void normalizeMinMax(Map<Object, Double> votes) {
        if (votes.isEmpty()) {
            return;
        }
        // Use the real extremes of the counts rather than assuming k bounds them.
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double count : votes.values()) {
            if (count < min) {
                min = count;
            }
            if (count > max) {
                max = count;
            }
        }
        if (max == min) {
            // All classes tied: dividing by zero is impossible, so keep the raw counts.
            return;
        }
        double range = max - min;
        // setValue on the entry is not a structural modification, so iteration stays valid.
        for (Map.Entry<Object, Double> entry : votes.entrySet()) {
            entry.setValue((entry.getValue() - min) / range);
        }
    }
}
