package it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonTypeName;
import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.MetaLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction;
import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier;
import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.MultiLabelClassifier;
import java.util.Arrays;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Meta learning algorithm for multi-label classification.
 *
 * <p>For every label in {@link #getLabels()} an independent copy of the base algorithm is created
 * with {@link LearningAlgorithm#duplicate()} and configured with that single label. The prediction
 * functions of those copies are handed to a {@link MultiLabelClassifier} (via
 * {@code setBinaryClassifiers}), which is exposed through {@link #getPredictionFunction()}.
 *
 * <p>The per-label learners are built lazily on the first call to {@link #learn(Dataset)}, or
 * restored by {@link #setPredictionFunction(PredictionFunction)}. Changing the labels afterwards
 * does not rebuild them.
 */
@JsonTypeName("multiLabel")
public class MultiLabelClassificationLearning implements ClassificationLearningAlgorithm, MetaLearningAlgorithm {
    private static final Logger LOGGER = LoggerFactory.getLogger(MultiLabelClassificationLearning.class);

    /** Template algorithm; it is duplicated once per label and never trained directly. */
    private LearningAlgorithm baseAlgorithm;

    /** One learner per label, index-aligned with {@link #labels}; {@code null} until built. */
    @JsonIgnore
    private LearningAlgorithm[] algorithms;

    private List<Label> labels;

    @JsonIgnore
    private MultiLabelClassifier classifier = new MultiLabelClassifier();

    /**
     * Sets the labels to learn and forwards them to the combined classifier.
     *
     * @param list the labels; one per-label learner is created for each entry
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public void setLabels(List<Label> list) {
        this.labels = list;
        this.classifier.setLabels(list);
    }

    /** @return the labels previously set, or {@code null} if none were set */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public List<Label> getLabels() {
        return this.labels;
    }

    /**
     * Returns the base algorithm, failing fast when it has not been configured.
     *
     * @throws IllegalStateException if no base algorithm has been set
     */
    private LearningAlgorithm requireBaseAlgorithm() {
        if (this.baseAlgorithm == null) {
            throw new IllegalStateException("A base algorithm must be set before per-label learners can be created");
        }
        return this.baseAlgorithm;
    }

    /**
     * Builds one learner per label and wires their prediction functions into the combined
     * classifier. State is only updated once every learner has been created, so a failure leaves
     * this object unchanged and a later call can retry.
     *
     * @throws IllegalStateException if labels or the base algorithm are missing, or if a per-label
     *     learner cannot be created (the original failure is attached as the cause)
     */
    private void initialize() {
        if (this.labels == null) {
            throw new IllegalStateException("Labels must be set before learning");
        }
        LearningAlgorithm base = requireBaseAlgorithm();
        int labelCount = this.labels.size();
        LearningAlgorithm[] perLabel = new LearningAlgorithm[labelCount];
        Classifier[] binaryClassifiers = new Classifier[labelCount];
        for (int i = 0; i < labelCount; i++) {
            Label label = this.labels.get(i);
            try {
                LearningAlgorithm learner = base.duplicate();
                learner.setLabels(Arrays.asList(label));
                perLabel[i] = learner;
                binaryClassifiers[i] = (Classifier) learner.getPredictionFunction();
            } catch (Exception e) {
                // Previously this called System.exit(0): never terminate the host JVM from a library.
                throw new IllegalStateException("Unable to create the learner for label " + label, e);
            }
        }
        this.classifier.setBinaryClassifiers(binaryClassifiers);
        this.algorithms = perLabel;
    }

    /**
     * Trains every per-label learner on {@code dataset}, building the learners first if needed.
     * A failure of one learner is logged and does not stop training of the remaining ones.
     *
     * @param dataset the training data, passed unchanged to each per-label learner
     * @throws IllegalStateException if the learners must be built and labels or the base
     *     algorithm are missing, or a learner cannot be created
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public void learn(Dataset dataset) {
        if (this.algorithms == null) {
            initialize();
        }
        for (LearningAlgorithm learner : this.algorithms) {
            try {
                learner.learn(dataset);
            } catch (Exception e) {
                // Best-effort: keep training the other labels, but do not hide the failure.
                LOGGER.error("Training failed for labels {}; continuing with the remaining labels", learner.getLabels(), e);
            }
        }
    }

    /** Resets every per-label learner that has been built; does nothing if none exist yet. */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public void reset() {
        if (this.algorithms != null) {
            for (LearningAlgorithm learningAlgorithm : this.algorithms) {
                learningAlgorithm.reset();
            }
        }
    }

    /** @return the combined classifier backed by the per-label learners */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public MultiLabelClassifier getPredictionFunction() {
        return this.classifier;
    }

    /** @param learningAlgorithm the template duplicated once per label */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.MetaLearningAlgorithm
    public void setBaseAlgorithm(LearningAlgorithm learningAlgorithm) {
        this.baseAlgorithm = learningAlgorithm;
    }

    /** @return the template algorithm, or {@code null} if none was set */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.MetaLearningAlgorithm
    public LearningAlgorithm getBaseAlgorithm() {
        return this.baseAlgorithm;
    }

    /**
     * Creates an untrained instance with the same configuration (base algorithm and labels).
     * The trained per-label learners are not copied.
     *
     * @return a new, untrained instance
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public MultiLabelClassificationLearning duplicate() {
        MultiLabelClassificationLearning copy = new MultiLabelClassificationLearning();
        copy.setBaseAlgorithm(this.baseAlgorithm);
        if (this.labels != null) {
            copy.setLabels(this.labels);
        }
        return copy;
    }

    /**
     * Restores this learner from an existing {@link MultiLabelClassifier}. One base-algorithm
     * duplicate is created per label of the classifier and bound to the matching binary
     * classifier. State is only replaced after all learners have been built.
     *
     * @param predictionFunction must be a {@link MultiLabelClassifier}
     * @throws ClassCastException if {@code predictionFunction} is not a {@link MultiLabelClassifier}
     * @throws IllegalStateException if no base algorithm has been set
     * @throws IllegalArgumentException if the classifier has fewer binary classifiers than labels
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public void setPredictionFunction(PredictionFunction predictionFunction) {
        MultiLabelClassifier multiLabelClassifier = (MultiLabelClassifier) predictionFunction;
        List<Label> classifierLabels = multiLabelClassifier.getLabels();
        LearningAlgorithm base = requireBaseAlgorithm();
        int labelCount = classifierLabels.size();
        if (multiLabelClassifier.getBinaryClassifiers() == null
                || multiLabelClassifier.getBinaryClassifiers().length < labelCount) {
            throw new IllegalArgumentException(
                    "The classifier must provide one binary classifier per label (" + labelCount + " labels)");
        }
        LearningAlgorithm[] restored = new LearningAlgorithm[labelCount];
        for (int i = 0; i < labelCount; i++) {
            LearningAlgorithm learner = base.duplicate();
            learner.setLabels(Arrays.asList(classifierLabels.get(i)));
            learner.setPredictionFunction(multiLabelClassifier.getBinaryClassifiers()[i]);
            restored[i] = learner;
        }
        this.classifier = multiLabelClassifier;
        this.labels = classifierLabels;
        this.algorithms = restored;
    }
}
