package it.uniroma2.sag.kelp.predictionfunction.model;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonTypeName;
import gnu.trove.map.TLongIntMap;
import gnu.trove.map.hash.TLongIntHashMap;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.kernel.Kernel;
import java.util.ArrayList;
import java.util.List;

/**
 * Binary model backed by a kernel and a list of weighted support vectors.
 *
 * <p>Besides the list itself, the model keeps an id-to-position index
 * ({@code fromIdToPosition}) so that the support vector built on a given
 * example can be found without scanning the list. The index stores
 * <em>list index + 1</em>: the lookup methods interpret a returned {@code 0}
 * as "no entry" (this relies on the Trove map returning 0 for missing keys,
 * which is its default no-entry value), so a real position must never be
 * encoded as 0.
 *
 * <p>The index is annotated with {@link JsonIgnore}; it is rebuilt from the
 * list whenever {@link #setSupportVectors(List)} is invoked.
 */
@JsonTypeName("binarykernelmodel")
public class BinaryKernelMachineModel extends BinaryModel implements KernelMachineModel {
    /** Kernel used to compute norms and inner products between examples. */
    private Kernel kernel;

    /** Support vectors of the model, in insertion order. */
    private List<SupportVector> supportVectors;

    /** Maps an example id to (index in {@code supportVectors}) + 1; 0 means absent. */
    @JsonIgnore
    private TLongIntMap fromIdToPosition;

    /**
     * Creates an empty model that uses the given kernel.
     *
     * @param kernel the kernel to associate with this model
     */
    public BinaryKernelMachineModel(Kernel kernel) {
        this.kernel = null;
        setKernel(kernel);
        this.supportVectors = new ArrayList<SupportVector>();
        this.fromIdToPosition = new TLongIntHashMap();
    }

    /** Creates an empty model with no kernel set. */
    public BinaryKernelMachineModel() {
        this.kernel = null;
        this.supportVectors = new ArrayList<SupportVector>();
        this.fromIdToPosition = new TLongIntHashMap();
    }

    /**
     * @return the kernel currently associated with this model (may be {@code null})
     */
    @Override // it.uniroma2.sag.kelp.predictionfunction.model.KernelMachineModel
    public Kernel getKernel() {
        return this.kernel;
    }

    /**
     * @param kernel the kernel to associate with this model
     */
    @Override // it.uniroma2.sag.kelp.predictionfunction.model.KernelMachineModel
    public void setKernel(Kernel kernel) {
        this.kernel = kernel;
    }

    /**
     * @return the internal (live, mutable) list of support vectors
     */
    public List<SupportVector> getSupportVectors() {
        return this.supportVectors;
    }

    /**
     * Replaces the whole list of support vectors and rebuilds the id index.
     *
     * @param list the new support vectors; the list is stored as is, not copied
     */
    public void setSupportVectors(List<SupportVector> list) {
        this.supportVectors = list;
        this.fromIdToPosition.clear();
        for (int i = 0; i < list.size(); i++) {
            // Stored as index + 1 so that 0 can mean "not present".
            this.fromIdToPosition.put(list.get(i).getInstance().getId(), i + 1);
        }
    }

    /**
     * Appends a support vector to the model and registers it in the id index.
     *
     * @param supportVector the support vector to append
     */
    public void addSupportVector(SupportVector supportVector) {
        // The new element will land at index size(), hence encoded as size() + 1.
        this.fromIdToPosition.put(supportVector.getInstance().getId(), getNumberOfSupportVectors() + 1);
        this.supportVectors.add(supportVector);
    }

    /** Sets the bias to zero and removes every support vector. */
    @Override // it.uniroma2.sag.kelp.predictionfunction.model.Model
    public void reset() {
        this.bias = 0.0f;
        this.supportVectors.clear();
        this.fromIdToPosition.clear();
    }

    /**
     * Adds an example with the given weight: if a support vector for the
     * example already exists its weight is incremented, otherwise a new
     * support vector is appended.
     *
     * @param f       the weight (or weight increment) for the example
     * @param example the example to add
     */
    @Override // it.uniroma2.sag.kelp.predictionfunction.model.BinaryModel
    public void addExample(float f, Example example) {
        SupportVector supportVector = getSupportVector(example);
        if (supportVector == null) {
            addSupportVector(new SupportVector(f, example));
        } else {
            supportVector.incrementWeight(f);
        }
    }

    /**
     * @param example the example whose squared norm is requested
     * @return the squared norm of the example as computed by the kernel
     */
    @Override // it.uniroma2.sag.kelp.predictionfunction.model.BinaryModel
    public float getSquaredNorm(Example example) {
        return this.kernel.squaredNorm(example);
    }

    /**
     * Looks up the support vector registered for the id of the given example.
     *
     * @param example the example to look for
     * @return the matching support vector, or {@code null} if none is registered
     */
    public SupportVector getSupportVector(Example example) {
        int i = this.fromIdToPosition.get(example.getId());
        if (i == 0) {
            // 0 is the "no entry" marker; real positions are stored as index + 1.
            return null;
        }
        return this.supportVectors.get(i - 1);
    }

    /**
     * Looks up the list index of the support vector registered for the id of
     * the given example.
     *
     * @param example the example to look for
     * @return the zero-based index, or {@code null} if none is registered
     */
    public Integer getSupportVectorIndex(Example example) {
        int i = this.fromIdToPosition.get(example.getId());
        if (i == 0) {
            return null;
        }
        return Integer.valueOf(i - 1);
    }

    /**
     * @param example the example to test
     * @return {@code true} if a support vector is registered for the example's id
     */
    @Override // it.uniroma2.sag.kelp.predictionfunction.model.KernelMachineModel
    public boolean isSupportVector(Example example) {
        return getSupportVector(example) != null;
    }

    /**
     * Replaces the support vector at the given position, keeping the id index
     * consistent.
     *
     * @param supportVector the new support vector
     * @param i             the zero-based position to overwrite
     */
    public void setSupportVector(SupportVector supportVector, int i) {
        // Drop the index entry of the vector being overwritten.
        this.fromIdToPosition.remove(this.supportVectors.get(i).getInstance().getId());
        this.supportVectors.set(i, supportVector);
        // Must use the same index + 1 encoding as every other writer; storing
        // the raw index would make position 0 look absent and shift the rest.
        this.fromIdToPosition.put(supportVector.getInstance().getId(), i + 1);
    }

    /**
     * Reuses the support vector object at the given position for a different
     * example and weight, keeping the id index consistent.
     *
     * @param i       the zero-based position of the support vector to reuse
     * @param example the example that becomes the new instance
     * @param f       the new weight
     */
    public void substituteSupportVector(int i, Example example, float f) {
        SupportVector target = this.supportVectors.get(i);
        // Unregister the instance being replaced; otherwise its id would keep
        // resolving to this slot even though the slot now holds another example.
        this.fromIdToPosition.remove(target.getInstance().getId());
        target.setInstance(example);
        target.setWeight(f);
        // Same index + 1 encoding used by the lookups.
        this.fromIdToPosition.put(example.getId(), i + 1);
    }

    /**
     * @return the number of support vectors currently in the model
     */
    @Override // it.uniroma2.sag.kelp.predictionfunction.model.KernelMachineModel
    public int getNumberOfSupportVectors() {
        return this.supportVectors.size();
    }

    /**
     * Computes the sum, over every ordered pair of support vectors, of the
     * product of the two weights and the kernel inner product of the two
     * instances.
     *
     * @return the accumulated value
     */
    @Override // it.uniroma2.sag.kelp.predictionfunction.model.BinaryModel
    @JsonIgnore
    public float getSquaredNorm() {
        float f = 0.0f;
        for (SupportVector supportVector : this.supportVectors) {
            for (SupportVector supportVector2 : this.supportVectors) {
                f += supportVector.getWeight() * supportVector2.getWeight() * this.kernel.innerProduct(supportVector.getInstance(), supportVector2.getInstance());
            }
        }
        return f;
    }
}
