package it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass;

import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:it/uniroma2/sag/kelp/predictionfunction/classifier/multiclass/OneVsAllClassificationOutput.class */
/**
 * Output of a One-Vs-All multiclass classification: one binary score per class label, where the
 * predicted class is the label with the highest score.
 *
 * <p>Ties are resolved in favour of the label that was added first. Instances are mutable and
 * not synchronized.
 */
public class OneVsAllClassificationOutput implements ClassificationOutput {
    /** Binary score per label; insertion-ordered so that argmax tie-breaking is deterministic. */
    private final Map<Label, Float> binaryOutputs = new LinkedHashMap<>();

    /** Label with the highest score so far; {@code null} while no prediction has been added. */
    private Label argmax;

    /**
     * Records the score produced by the binary classifier associated with {@code label}.
     *
     * <p>If {@code label} already had a score, it is overwritten and the predicted class is
     * recomputed so it stays consistent with the updated scores.
     *
     * @param label the class label the binary score refers to
     * @param f the binary score for {@code label}
     */
    public void addBinaryPrediction(Label label, float f) {
        Float previous = this.binaryOutputs.put(label, Float.valueOf(f));
        if (previous != null) {
            // An existing score was overwritten. If the current winner was lowered, another label
            // may now be the maximum, which the incremental comparison below cannot detect.
            this.argmax = computeArgmax();
        } else if (this.argmax == null || this.binaryOutputs.get(this.argmax).floatValue() < f) {
            // Strict '<' keeps the earlier label on ties.
            this.argmax = label;
        }
    }

    /**
     * Returns the first-inserted label with the highest score, or {@code null} if no scores
     * are stored.
     */
    private Label computeArgmax() {
        Label best = null;
        float bestScore = 0f;
        boolean first = true;
        for (Map.Entry<Label, Float> entry : this.binaryOutputs.entrySet()) {
            float score = entry.getValue().floatValue();
            if (first || bestScore < score) {
                best = entry.getKey();
                bestScore = score;
                first = false;
            }
        }
        return best;
    }

    /**
     * Returns the binary score recorded for {@code label}.
     *
     * @param label the class label to look up
     * @return the score, or {@code null} if no score was added for {@code label}
     */
    @Override
    public Float getScore(Label label) {
        return this.binaryOutputs.get(label);
    }

    /**
     * Tells whether {@code label} is the predicted (highest-scoring) class.
     *
     * @param label the class label to test; must not be {@code null}
     * @return {@code true} if {@code label} equals the current argmax label
     */
    @Override
    public boolean isClassPredicted(Label label) {
        return label.equals(this.argmax);
    }

    /**
     * Returns the predicted classes: a single-element list holding the highest-scoring label,
     * or an empty list if no prediction has been added yet.
     *
     * @return a new mutable list with at most one label
     */
    @Override
    public List<Label> getPredictedClasses() {
        List<Label> predicted = new ArrayList<>(1);
        if (!this.binaryOutputs.isEmpty()) {
            predicted.add(this.argmax);
        }
        return predicted;
    }

    /**
     * Returns every label for which a binary score has been added, in insertion order.
     *
     * @return a new mutable list of labels
     */
    @Override
    public List<Label> getAllClasses() {
        return new ArrayList<>(this.binaryOutputs.keySet());
    }
}
