package it.uniroma2.sag.kelp.predictionfunction;

import com.fasterxml.jackson.annotation.JsonIgnore;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.example.SequenceExample;
import it.uniroma2.sag.kelp.data.example.SequencePath;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.data.label.SequenceEmission;
import it.uniroma2.sag.kelp.predictionfunction.model.Model;
import it.uniroma2.sag.kelp.predictionfunction.model.SequenceModel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;
import org.apache.commons.lang3.SerializationUtils;

/* loaded from: input_file:it/uniroma2/sag/kelp/predictionfunction/SequencePredictionFunction.class */
/**
 * Prediction function that labels a {@link SequenceExample} with a beam search over partial
 * label paths.
 *
 * <p>At each position of the sequence, every partial path in the beam is passed, together with
 * the sequence and the position, to the model's sequence example generator. The resulting example
 * is scored by the model's base prediction function. The scores are turned into a softmax
 * distribution over the labels, and the path is extended with the top-ranked labels. A path's
 * score is accumulated by adding the natural log of the probability of each label appended to it.
 * After each position only the {@link #getBeamSize()} top-ranked paths are kept.
 *
 * <p>When the generator's transitions order is 0, the single path is extended in place with only
 * the top-ranked label at each position (greedy decoding).
 */
public class SequencePredictionFunction implements PredictionFunction {
    /** Default number of top-ranked labels used to extend each partial path. */
    public static final int DEFAULT_MAX_EMISSION_CAND = 5;

    /** Default number of partial paths kept after each decoding step. */
    public static final int DEFAULT_BEAM_SIZE = 20;

    private int maxEmissionCandidates;
    private int beamSize;
    private SequenceModel model;

    /** Creates a prediction function with the default beam parameters and no model. */
    public SequencePredictionFunction() {
        this.maxEmissionCandidates = DEFAULT_MAX_EMISSION_CAND;
        this.beamSize = DEFAULT_BEAM_SIZE;
    }

    /**
     * Creates a prediction function with the default beam parameters, backed by the given model.
     *
     * @param sequenceModel model providing the base prediction function and the example generator
     */
    public SequencePredictionFunction(SequenceModel sequenceModel) {
        this();
        this.model = sequenceModel;
    }

    /**
     * Returns the number of partial paths kept after each decoding step.
     *
     * @return the beam size
     */
    public int getBeamSize() {
        return this.beamSize;
    }

    /**
     * Converts the scores of a base prediction into a softmax probability distribution over the
     * labels of the base prediction function.
     *
     * <p>The maximum score is subtracted before exponentiating. The result is mathematically
     * unchanged, but {@code Math.exp} can no longer overflow to infinity, which would otherwise
     * turn every probability into NaN when scores are large.
     *
     * @param prediction the base prediction whose per-label scores are normalized
     * @return a map from each label to its probability; empty if there are no labels
     */
    private HashMap<Label, Float> getEmissionsProbabilities(Prediction prediction) {
        List<Label> labels = this.model.getBasePredictionFunction().getLabels();
        HashMap<Label, Float> probabilities = new HashMap<>();
        if (labels.isEmpty()) {
            return probabilities;
        }

        // Read every score exactly once and remember the maximum for the stable softmax.
        double[] weights = new double[labels.size()];
        double maxScore = Double.NEGATIVE_INFINITY;
        int index = 0;
        for (Label label : labels) {
            double score = prediction.getScore(label).floatValue();
            weights[index++] = score;
            if (score > maxScore) {
                maxScore = score;
            }
        }

        double normalizer = 0.0d;
        for (int i = 0; i < weights.length; i++) {
            weights[i] = Math.exp(weights[i] - maxScore);
            normalizer += weights[i];
        }

        index = 0;
        for (Label label : labels) {
            probabilities.put(label, Float.valueOf((float) (weights[index++] / normalizer)));
        }
        return probabilities;
    }

    @Override
    @JsonIgnore
    public List<Label> getLabels() {
        return this.model.getBasePredictionFunction().getLabels();
    }

    /**
     * Returns the maximum number of top-ranked labels used to extend each partial path when the
     * transitions order is not 0.
     *
     * @return the maximum number of emission candidates per path
     */
    public int getMaxEmissionCandidates() {
        return this.maxEmissionCandidates;
    }

    @Override
    public SequenceModel getModel() {
        return this.model;
    }

    /**
     * Labels a sequence with a beam search (or greedily, when the transitions order is 0).
     *
     * @param example the example to label; must be a {@link SequenceExample}
     * @return a prediction holding the surviving paths. For a non-empty sequence they are in
     *     descending {@code SequencePath} natural order, presumably by score (verify against
     *     {@code SequencePath.compareTo}).
     * @throws ClassCastException if {@code example} is not a {@link SequenceExample}
     */
    @Override
    public SequencePrediction predict(Example example) {
        SequenceExample sequenceExample = (SequenceExample) example;
        List<Label> labels = getLabels();
        boolean greedy = this.model.getSequenceExampleGenerator().getTransitionsOrder() == 0;

        List<SequencePath> beam = new ArrayList<>();
        beam.add(new SequencePath());

        int length = sequenceExample.getLenght();
        for (int position = 0; position < length; position++) {
            List<SequencePath> nextBeam = new ArrayList<>();
            for (SequencePath partialPath : beam) {
                Example exampleWithHistory = this.model.getSequenceExampleGenerator()
                        .generateExampleWithHistory(sequenceExample, partialPath, position);
                HashMap<Label, Float> probabilities = getEmissionsProbabilities(
                        this.model.getBasePredictionFunction().predict(exampleWithHistory));
                List<SequenceEmission> ranked = rankEmissions(labels, probabilities);

                if (greedy) {
                    // Transitions order 0: extend this path in place with the top-ranked label.
                    SequenceEmission best = ranked.get(0);
                    partialPath.add(new SequenceEmission(best.getLabel(), best.getEmission()));
                    partialPath.setScore(Double.valueOf(
                            partialPath.getScore().doubleValue() + Math.log(best.getEmission())));
                    nextBeam.add(partialPath);
                } else {
                    List<SequenceEmission> candidates = ranked.size() > this.maxEmissionCandidates
                            ? ranked.subList(0, this.maxEmissionCandidates)
                            : ranked;
                    for (SequenceEmission candidate : candidates) {
                        // Deep-copy the path so that each candidate label gets its own branch.
                        SequencePath extended = SerializationUtils.clone(partialPath);
                        float probability = candidate.getEmission();
                        extended.getAssignedSequnceLabels()
                                .add(new SequenceEmission(candidate.getLabel(), probability));
                        extended.setScore(Double.valueOf(
                                extended.getScore().doubleValue() + Math.log(probability)));
                        nextBeam.add(extended);
                    }
                }
            }
            keepBestPaths(nextBeam);
            beam = nextBeam;
        }

        SequencePrediction sequencePrediction = new SequencePrediction();
        sequencePrediction.setPaths(beam);
        return sequencePrediction;
    }

    /**
     * Builds one emission per label and orders them in descending {@code SequenceEmission}
     * natural order (assumed to compare by emission probability; verify against
     * {@code SequenceEmission.compareTo}).
     */
    private static List<SequenceEmission> rankEmissions(
            List<Label> labels, HashMap<Label, Float> probabilities) {
        List<SequenceEmission> ranked = new ArrayList<>(labels.size());
        for (Label label : labels) {
            ranked.add(new SequenceEmission(label, probabilities.get(label).floatValue()));
        }
        // Sort ascending, then reverse, so that ties are broken exactly as before.
        Collections.sort(ranked);
        Collections.reverse(ranked);
        return ranked;
    }

    /** Sorts the paths in descending natural order and drops every path beyond the beam size. */
    private void keepBestPaths(List<SequencePath> paths) {
        // Sort ascending, then reverse, so that ties are broken exactly as before.
        Collections.sort(paths);
        Collections.reverse(paths);
        if (paths.size() > this.beamSize) {
            paths.subList(this.beamSize, paths.size()).clear();
        }
    }

    /** Resets the base prediction function and then the model. */
    @Override
    public void reset() {
        this.model.getBasePredictionFunction().reset();
        this.model.reset();
    }

    /**
     * Sets the number of partial paths kept after each decoding step.
     *
     * @param i the new beam size
     */
    public void setBeamSize(int i) {
        this.beamSize = i;
    }

    /** Sets the labels on the model's base prediction function. */
    @Override
    @JsonIgnore
    public void setLabels(List<Label> list) {
        this.model.getBasePredictionFunction().setLabels(list);
    }

    /**
     * Sets the maximum number of top-ranked labels used to extend each partial path when the
     * transitions order is not 0.
     *
     * @param i the new maximum number of emission candidates per path
     */
    public void setMaxEmissionCandidates(int i) {
        this.maxEmissionCandidates = i;
    }

    /**
     * Sets the model used for decoding.
     *
     * @param model the new model; must be a {@link SequenceModel}
     * @throws ClassCastException if {@code model} is not a {@link SequenceModel}
     */
    @Override
    public void setModel(Model model) {
        this.model = (SequenceModel) model;
    }
}
