package it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass;

import com.fasterxml.jackson.annotation.JsonTypeName;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput;
import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryModel;
import it.uniroma2.sag.kelp.predictionfunction.model.Model;
import it.uniroma2.sag.kelp.predictionfunction.model.MulticlassModel;
import java.util.ArrayList;
import java.util.List;

/**
 * Multi-label classifier composed of independent binary classifiers.
 *
 * <p>On {@link #predict(Example)} every binary classifier is queried. The first class in its
 * output, together with that class's score, is recorded in a
 * {@link MultiLabelClassificationOutput}. The {@link MulticlassModel} exposed by
 * {@link #getModel()} aggregates the {@link BinaryModel}s of the classifiers passed to
 * {@link #setBinaryClassifiers(Classifier[])}.
 */
@JsonTypeName("multilabelClassifier")
public class MultiLabelClassifier implements Classifier {
    /** Per-label binary classifiers; {@code null} until {@link #setBinaryClassifiers} is called. */
    private Classifier[] binaryClassifiers;
    /** Labels returned by {@link #getLabels()}; stored independently of the binary classifiers. */
    private List<Label> labels;
    /** Aggregate model holding the binary classifiers' models. */
    private MulticlassModel model = new MulticlassModel();

    /**
     * Returns the binary classifiers backing this multi-label classifier.
     *
     * @return the array passed to {@link #setBinaryClassifiers}, or {@code null} if never set
     */
    public Classifier[] getBinaryClassifiers() {
        return this.binaryClassifiers;
    }

    /**
     * Sets the binary classifiers and rebuilds the aggregate model from their models.
     *
     * <p>The binary models are collected before any field is updated. If a classifier's model is
     * not a {@link BinaryModel}, the resulting {@code ClassCastException} leaves this instance
     * unchanged instead of half-updated.
     *
     * @param classifierArr the binary classifiers; each must expose a {@link BinaryModel}
     * @throws ClassCastException if some classifier's model is not a {@link BinaryModel}
     */
    public void setBinaryClassifiers(Classifier[] classifierArr) {
        List<BinaryModel> binaryModels = new ArrayList<>(classifierArr.length);
        for (Classifier classifier : classifierArr) {
            binaryModels.add((BinaryModel) classifier.getModel());
        }
        this.binaryClassifiers = classifierArr;
        this.model.setModels(binaryModels);
    }

    /**
     * Runs every binary classifier on {@code example} and collects one prediction per classifier.
     *
     * @param example the example to classify
     * @return the per-label binary predictions
     * @throws IllegalStateException if the binary classifiers have not been set
     */
    @Override
    public MultiLabelClassificationOutput predict(Example example) {
        if (this.binaryClassifiers == null) {
            throw new IllegalStateException("binary classifiers have not been set");
        }
        MultiLabelClassificationOutput output = new MultiLabelClassificationOutput();
        for (Classifier classifier : this.binaryClassifiers) {
            ClassificationOutput binaryOutput = classifier.predict(example);
            // NOTE(review): assumes a binary classifier reports its target label first in
            // getAllClasses() — confirm against the binary output implementations.
            Label label = binaryOutput.getAllClasses().get(0);
            output.addBinaryPrediction(label, binaryOutput.getScore(label).floatValue());
        }
        return output;
    }

    /** Resets every binary classifier; does nothing if none have been set yet. */
    @Override
    public void reset() {
        if (this.binaryClassifiers == null) {
            return;
        }
        for (Classifier classifier : this.binaryClassifiers) {
            classifier.reset();
        }
    }

    /**
     * Sets the labels reported by {@link #getLabels()}.
     *
     * @param list the labels
     */
    @Override
    public void setLabels(List<Label> list) {
        this.labels = list;
    }

    /**
     * Returns the labels previously set via {@link #setLabels(List)}.
     *
     * @return the labels, or {@code null} if never set
     */
    @Override
    public List<Label> getLabels() {
        return this.labels;
    }

    /**
     * Returns the aggregate model of the binary classifiers.
     *
     * @return the multiclass model
     */
    @Override
    public MulticlassModel getModel() {
        return this.model;
    }

    /**
     * Replaces the aggregate model.
     *
     * <p>Only the model reference is replaced; the binary classifiers are not updated to match.
     *
     * @param model the new model; must be a {@link MulticlassModel}
     * @throws ClassCastException if {@code model} is not a {@link MulticlassModel}
     */
    @Override
    public void setModel(Model model) {
        this.model = (MulticlassModel) model;
    }
}
