package it.uniroma2.sag.kelp.kernel.tree;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonTypeName;
import it.uniroma2.sag.kelp.data.representation.tree.TreeRepresentation;
import it.uniroma2.sag.kelp.data.representation.tree.node.TreeNode;
import it.uniroma2.sag.kelp.data.representation.tree.node.TreeNodePairs;
import it.uniroma2.sag.kelp.kernel.DirectKernel;
import it.uniroma2.sag.kelp.kernel.tree.deltamatrix.DeltaMatrix;
import it.uniroma2.sag.kelp.kernel.tree.deltamatrix.StaticDeltaMatrix;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Partial Tree Kernel (PTK) over {@link TreeRepresentation}s.
 *
 * <p>The kernel value is the sum of a delta function over every pair of nodes (one node from each
 * tree) that carry the same label. For a matching pair the delta is
 * {@code mu * lambda^2 * terminalFactor} when either node has no children, and otherwise
 * {@code mu * (lambda^2 + S)}, where {@code S} is computed by a dynamic program over the two
 * ordered child sequences (decay {@code lambda}, subsequence length capped by
 * {@code maxSubseqLeng}).
 *
 * <p>Delta values are memoized in a {@link DeltaMatrix} (by default
 * {@code StaticDeltaMatrix.getInstance()}), and the dynamic-programming tables are reused across
 * calls and grown on demand. Instances therefore carry mutable per-evaluation state.
 * NOTE(review): one instance does not look safe for concurrent use — confirm before sharing it
 * across threads.
 */
@JsonTypeName("ptk")
public class PartialTreeKernel extends DirectKernel<TreeRepresentation> {

    /** Marker stored in the delta matrix for a node pair whose delta has not been computed yet. */
    private static final float NO_RESPONSE = -1.0f;

    /** Initial number of children per node the DP tables can host. */
    private static final int DEFAULT_MAX_CHILDREN = 50;

    /** Initial recursion depth (tree height) the DP tables can host. */
    private static final int DEFAULT_MAX_RECURSION = 20;

    private static final Logger logger = LoggerFactory.getLogger(PartialTreeKernel.class);

    // Current capacity of the DP tables; grown by ensureBufferCapacity().
    private int MAX_CHILDREN = DEFAULT_MAX_CHILDREN;
    private int MAX_RECURSION = DEFAULT_MAX_RECURSION;

    private float mu;
    private float lambda;

    // Cached lambda * lambda; kept in sync by the constructor and setLambda().
    @JsonIgnore
    private float lambda2;

    private float terminalFactor = 1.0f;
    private int maxSubseqLeng = Integer.MAX_VALUE;

    @JsonIgnore
    private DeltaMatrix deltaMatrix = StaticDeltaMatrix.getInstance();

    // Index of the DP table level used by the current stringKernelDeltaFunction call.
    private int recursion_id = 0;

    // Per-recursion-level scratch tables for stringKernelDeltaFunction.
    private float[][] kernel_mat_buffer;
    private float[][][] DPS_buffer;
    private float[][][] DP_buffer;

    /** Creates a PTK with {@code lambda = 0.4}, {@code mu = 0.4}, {@code terminalFactor = 1} on representation "0". */
    public PartialTreeKernel() {
        this(0.4f, 0.4f, 1.0f, "0");
    }

    /**
     * Creates a PTK with the given parameters.
     *
     * @param lambda decay used inside the child-sequence dynamic program; its square also enters
     *     every node-pair delta
     * @param mu factor multiplying every node-pair delta
     * @param terminalFactor extra factor applied to the delta of a pair in which at least one node
     *     has no children
     * @param representationIdentifier forwarded to the {@code DirectKernel} constructor
     *     (presumably the identifier of the tree representation to read — confirm)
     */
    public PartialTreeKernel(float lambda, float mu, float terminalFactor, String representationIdentifier) {
        super(representationIdentifier);
        this.lambda = lambda;
        this.lambda2 = lambda * lambda;
        this.mu = mu;
        this.terminalFactor = terminalFactor;
        allocateBuffers();
    }

    /**
     * Creates a PTK with {@code lambda = 0.4}, {@code mu = 0.4}, {@code terminalFactor = 1}.
     *
     * @param representationIdentifier forwarded to the {@code DirectKernel} constructor
     */
    public PartialTreeKernel(String representationIdentifier) {
        this(0.4f, 0.4f, 1.0f, representationIdentifier);
    }

    /** Returns the label used to match nodes. */
    private static String labelOf(TreeNode node) {
        return node.getContent().getTextFromData();
    }

    /**
     * (Re)allocates the DP scratch tables for the current {@code MAX_RECURSION} and
     * {@code MAX_CHILDREN}.
     *
     * <p>Rows/columns are {@code MAX_CHILDREN + 1} because the DP indexes children 1-based with a
     * 0 border.
     */
    private void allocateBuffers() {
        this.kernel_mat_buffer = new float[this.MAX_RECURSION][this.MAX_CHILDREN];
        this.DPS_buffer = new float[this.MAX_RECURSION][this.MAX_CHILDREN + 1][this.MAX_CHILDREN + 1];
        this.DP_buffer = new float[this.MAX_RECURSION][this.MAX_CHILDREN + 1][this.MAX_CHILDREN + 1];
    }

    /**
     * Grows the DP scratch tables, if needed, so they can host trees with the given maximum
     * branching factor and height.
     */
    private void ensureBufferCapacity(int maxBranchingFactor, int maxHeight) {
        boolean needMoreChildren = maxBranchingFactor >= this.MAX_CHILDREN;
        boolean needMoreDepth = maxHeight > this.MAX_RECURSION;
        if (!needMoreChildren && !needMoreDepth) {
            return;
        }
        if (needMoreChildren) {
            this.MAX_CHILDREN = maxBranchingFactor + 1;
        }
        if (needMoreDepth) {
            this.MAX_RECURSION = maxHeight;
        }
        logger.warn("Increasing the size of cache matrices to host trees with height={} and maxBranchingFactor={}",
                this.MAX_RECURSION, this.MAX_CHILDREN);
        allocateBuffers();
    }

    /**
     * Collects every pair of nodes (one from each tree) sharing the same label. Each pair's entry
     * in the delta matrix is marked as "not yet computed".
     *
     * <p>This is a merge over the two label-ordered node lists. It assumes
     * {@code getOrderedNodeSetByLabel()} returns nodes sorted ascending by label in
     * {@link String#compareTo} order. For each label common to both trees, every node of the
     * first tree with that label is paired with every node of the second tree with that label.
     */
    private ArrayList<TreeNodePairs> determineSubList(TreeRepresentation tree1, TreeRepresentation tree2) {
        ArrayList<TreeNodePairs> pairs = new ArrayList<>();
        List<TreeNode> nodesA = tree1.getOrderedNodeSetByLabel();
        List<TreeNode> nodesB = tree2.getOrderedNodeSetByLabel();
        int sizeA = nodesA.size();
        int sizeB = nodesB.size();
        int i = 0;
        int j = 0;
        while (i < sizeA && j < sizeB) {
            String label = labelOf(nodesA.get(i));
            int cmp = label.compareTo(labelOf(nodesB.get(j)));
            if (cmp > 0) {
                j++;
            } else if (cmp < 0) {
                i++;
            } else {
                // Locate the run [runStartB, runEndB) of B nodes carrying this label.
                int runStartB = j;
                int runEndB = j;
                while (runEndB < sizeB && label.equals(labelOf(nodesB.get(runEndB)))) {
                    runEndB++;
                }
                // Pair every A node with this label against the whole B run.
                while (i < sizeA && label.equals(labelOf(nodesA.get(i)))) {
                    TreeNode nodeA = nodesA.get(i);
                    for (int k = runStartB; k < runEndB; k++) {
                        TreeNode nodeB = nodesB.get(k);
                        pairs.add(new TreeNodePairs(nodeA, nodeB));
                        this.deltaMatrix.add(nodeA.getId().intValue(), nodeB.getId().intValue(), NO_RESPONSE);
                    }
                    i++;
                }
                // Resume after the run instead of re-scanning it.
                j = runEndB;
            }
        }
        return pairs;
    }

    /**
     * Computes the unnormalized PTK value between two trees.
     *
     * <p>Clears the delta matrix, grows the DP tables if the trees need it, then sums the delta
     * function over all same-label node pairs.
     *
     * @param tree1 first tree
     * @param tree2 second tree
     * @return the (unnormalized) kernel value
     */
    public float evaluateKernelNotNormalize(TreeRepresentation tree1, TreeRepresentation tree2) {
        this.deltaMatrix.clear();
        ensureBufferCapacity(
                Math.max(tree1.getBranchingFactor(), tree2.getBranchingFactor()),
                Math.max(tree1.getHeight(), tree2.getHeight()));
        float sum = 0.0f;
        for (TreeNodePairs pair : determineSubList(tree1, tree2)) {
            sum += ptkDeltaFunction(pair.getNx(), pair.getNz());
        }
        return sum;
    }

    /** Returns the matrix used to memoize node-pair deltas. */
    @JsonIgnore
    public DeltaMatrix getDeltaMatrix() {
        return this.deltaMatrix;
    }

    /** Returns the lambda decay factor. */
    public float getLambda() {
        return this.lambda;
    }

    /** Returns the mu factor. */
    public float getMu() {
        return this.mu;
    }

    /** Returns the factor applied to the delta of pairs involving a node without children. */
    public float getTerminalFactor() {
        return this.terminalFactor;
    }

    /** Delegates to {@link #evaluateKernelNotNormalize(TreeRepresentation, TreeRepresentation)}. */
    @Override // it.uniroma2.sag.kelp.kernel.DirectKernel
    public float kernelComputation(TreeRepresentation tree1, TreeRepresentation tree2) {
        return evaluateKernelNotNormalize(tree1, tree2);
    }

    /**
     * Returns the delta for a node pair, computing and memoizing it if absent.
     *
     * <p>Returns 0 for different labels, {@code mu * lambda^2 * terminalFactor} if either node has
     * no children, and otherwise {@code mu * (lambda^2 + stringKernelDeltaFunction(children))}.
     */
    private float ptkDeltaFunction(TreeNode nodeX, TreeNode nodeZ) {
        int idX = nodeX.getId().intValue();
        int idZ = nodeZ.getId().intValue();
        float cached = this.deltaMatrix.get(idX, idZ);
        if (cached != NO_RESPONSE) {
            return cached;
        }
        float delta;
        if (!labelOf(nodeX).equals(labelOf(nodeZ))) {
            delta = 0.0f;
        } else if (nodeX.getNoOfChildren() == 0 || nodeZ.getNoOfChildren() == 0) {
            delta = this.mu * this.lambda2 * this.terminalFactor;
        } else {
            delta = this.mu * (this.lambda2 + stringKernelDeltaFunction(nodeX.getChildren(), nodeZ.getChildren()));
        }
        this.deltaMatrix.add(idX, idZ, delta);
        return delta;
    }

    /**
     * Replaces the matrix used to memoize node-pair deltas.
     *
     * @param deltaMatrix the new delta matrix
     * @deprecated retained for backward compatibility; by default the kernel uses
     *     {@code StaticDeltaMatrix.getInstance()}.
     */
    @JsonIgnore
    @Deprecated
    public void setDeltaMatrix(DeltaMatrix deltaMatrix) {
        this.deltaMatrix = deltaMatrix;
    }

    /** Sets the lambda decay factor and refreshes the cached {@code lambda^2}. */
    public void setLambda(float lambda) {
        this.lambda = lambda;
        this.lambda2 = this.lambda * this.lambda;
    }

    /** Sets the mu factor. */
    public void setMu(float mu) {
        this.mu = mu;
    }

    /** Sets the factor applied to the delta of pairs involving a node without children. */
    public void setTerminalFactor(float terminalFactor) {
        this.terminalFactor = terminalFactor;
    }

    /**
     * Dynamic program over two ordered child sequences.
     *
     * <p>Sums, over subsequence lengths {@code 1..p}, contributions built from the deltas of
     * same-label child pairs. Here {@code p = min(|childrenX|, |childrenZ|, maxSubseqLeng)}.
     * The DP tables come from the buffer level indexed by {@code recursion_id}, so each nested
     * call (via {@link #ptkDeltaFunction}) uses its own level. The level counter is restored even
     * if the computation throws.
     */
    private float stringKernelDeltaFunction(List<TreeNode> childrenX, List<TreeNode> childrenZ) {
        int n = childrenX.size();
        int m = childrenZ.size();
        float[][] dps = this.DPS_buffer[this.recursion_id];
        float[][] dp = this.DP_buffer[this.recursion_id];
        float[] kernelMat = this.kernel_mat_buffer[this.recursion_id];
        this.recursion_id++;
        try {
            int p = Math.min(Math.min(n, m), this.maxSubseqLeng);

            // Length-1 contributions: delta of every same-label child pair.
            kernelMat[0] = 0.0f;
            for (int i = 1; i <= n; i++) {
                for (int j = 1; j <= m; j++) {
                    TreeNode cx = childrenX.get(i - 1);
                    TreeNode cz = childrenZ.get(j - 1);
                    if (labelOf(cx).equals(labelOf(cz))) {
                        dps[i][j] = ptkDeltaFunction(cx, cz);
                        kernelMat[0] += dps[i][j];
                    } else {
                        dps[i][j] = 0.0f;
                    }
                }
            }

            // Longer subsequences, one length at a time.
            for (int l = 1; l < p; l++) {
                kernelMat[l] = 0.0f;
                for (int j = 0; j <= m; j++) {
                    dp[l - 1][j] = 0.0f;
                }
                for (int i = 0; i <= n; i++) {
                    dp[i][l - 1] = 0.0f;
                }
                for (int i = l; i <= n; i++) {
                    for (int j = l; j <= m; j++) {
                        dp[i][j] = dps[i][j] + this.lambda * dp[i - 1][j] + this.lambda * dp[i][j - 1]
                                - this.lambda2 * dp[i - 1][j - 1];
                        TreeNode cx = childrenX.get(i - 1);
                        TreeNode cz = childrenZ.get(j - 1);
                        if (labelOf(cx).equals(labelOf(cz))) {
                            dps[i][j] = ptkDeltaFunction(cx, cz) * dp[i - 1][j - 1];
                            kernelMat[l] += dps[i][j];
                        }
                    }
                }
            }

            float total = 0.0f;
            for (int l = 0; l < p; l++) {
                total += kernelMat[l];
            }
            return total;
        } finally {
            this.recursion_id--;
        }
    }

    /** Returns the maximum child-subsequence length considered by the dynamic program. */
    public int getMaxSubseqLeng() {
        return this.maxSubseqLeng;
    }

    /** Sets the maximum child-subsequence length considered by the dynamic program. */
    public void setMaxSubseqLeng(int maxSubseqLeng) {
        this.maxSubseqLeng = maxSubseqLeng;
    }
}
