/*
 * Decompiled with CFR 0.152.
 */
package org.jitsi.srtp;

import java.security.GeneralSecurityException;
import java.util.Arrays;
import javax.crypto.AEADBadTagException;
import javax.crypto.spec.SecretKeySpec;
import org.jitsi.srtp.BaseSrtpCryptoContext;
import org.jitsi.srtp.SrtpErrorStatus;
import org.jitsi.srtp.SrtpKdf;
import org.jitsi.srtp.SrtpPolicy;
import org.jitsi.srtp.crypto.Aes;
import org.jitsi.srtp.crypto.JitsiOpenSslProvider;
import org.jitsi.srtp.crypto.SrtpCipher;
import org.jitsi.srtp.crypto.SrtpCipherGcm;
import org.jitsi.srtp.utils.SrtpPacketUtils;
import org.jitsi.utils.ByteArrayBuffer;
import org.jitsi.utils.logging2.Logger;

public class SrtpCryptoContext
extends BaseSrtpCryptoContext {
    protected SrtpCipher cipherAuthOnly;
    private int guessedROC;
    private int roc;
    private int s_l = 0;
    private final boolean sender;
    private boolean seqNumSet = false;

    public SrtpCryptoContext(boolean sender, int ssrc2, int roc, byte[] masterK, byte[] masterS, SrtpPolicy policy, Logger parentLogger) throws GeneralSecurityException {
        super(ssrc2, masterK, masterS, policy, parentLogger);
        this.sender = sender;
        this.roc = roc;
        if (!sender && policy.getEncType() == 5 && JitsiOpenSslProvider.isLoaded()) {
            try {
                this.cipherAuthOnly = new SrtpCipherGcm(new Aes.OpenSSLCipherFactory().createCipher("AES/GCM-AuthOnly/NoPadding"));
            }
            catch (Exception e) {
                this.cipherAuthOnly = this.cipher;
            }
        } else {
            this.cipherAuthOnly = this.cipher;
        }
        this.deriveSrtpKeys(masterK, masterS);
    }

    private SrtpErrorStatus authenticatePacket(ByteArrayBuffer pkt) {
        if (this.policy.getAuthType() != 0) {
            int tagLength = this.policy.getAuthTagLength();
            pkt.readRegionToBuff(pkt.getLength() - tagLength, tagLength, this.tempStore);
            pkt.shrink(tagLength);
            byte[] tagStore = this.authenticatePacketHmac(pkt, this.guessedROC);
            int nonEqual = 0;
            for (int i = 0; i < tagLength; ++i) {
                nonEqual |= this.tempStore[i] ^ tagStore[i];
            }
            if (nonEqual != 0) {
                return SrtpErrorStatus.AUTH_FAIL;
            }
        }
        return SrtpErrorStatus.OK;
    }

    SrtpErrorStatus checkReplay(int seqNo, long guessedIndex) {
        long localIndex = (long)this.roc << 16 | (long)this.s_l;
        long delta = guessedIndex - localIndex;
        if (delta > 0L) {
            return SrtpErrorStatus.OK;
        }
        if (-delta >= 64L) {
            if (this.sender) {
                this.logger.error(() -> "Discarding RTP packet with sequence number " + seqNo + ", SSRC " + (0xFFFFFFFFL & (long)this.ssrc) + " because it is outside the replay window! (roc " + this.roc + ", s_l " + this.s_l + ", guessedROC " + this.guessedROC);
            }
            return SrtpErrorStatus.REPLAY_OLD;
        }
        if ((this.replayWindow >>> (int)(-delta) & 1L) != 0L) {
            if (this.sender) {
                this.logger.error(() -> "Discarding RTP packet with sequence number " + seqNo + ", SSRC " + (0xFFFFFFFFL & (long)this.ssrc) + " because it has been received already! (roc " + this.roc + ", s_l " + this.s_l + ", guessedROC " + this.guessedROC);
            }
            return SrtpErrorStatus.REPLAY_FAIL;
        }
        return SrtpErrorStatus.OK;
    }

    private void deriveSrtpKeys(byte[] masterKey, byte[] masterSalt) throws GeneralSecurityException {
        SrtpKdf kdf = new SrtpKdf(masterKey, masterSalt, this.policy);
        kdf.deriveSessionKey(this.saltKey, (byte)2);
        if (this.cipher != null) {
            byte[] encKey = new byte[this.policy.getEncKeyLength()];
            kdf.deriveSessionKey(encKey, (byte)0);
            this.cipher.init(encKey, this.saltKey);
            if (this.cipherAuthOnly != this.cipher) {
                this.cipherAuthOnly.init(encKey, this.saltKey);
            }
        }
        if (this.mac != null) {
            byte[] authKey = new byte[this.policy.getAuthKeyLength()];
            kdf.deriveSessionKey(authKey, (byte)1);
            this.mac.init(new SecretKeySpec(authKey, this.mac.getAlgorithm()));
            Arrays.fill(authKey, (byte)0);
        }
    }

    private long guessIndex(int seqNo) {
        this.guessedROC = this.s_l < 32768 ? (seqNo - this.s_l > 32768 ? this.roc - 1 : this.roc) : (this.s_l - 32768 > seqNo ? this.roc + 1 : this.roc);
        return (long)this.guessedROC << 16 | (long)seqNo;
    }

    private void processPacketAesCm(ByteArrayBuffer pkt) throws GeneralSecurityException {
        int i;
        int ssrc2 = SrtpPacketUtils.getSsrc(pkt);
        int seqNo = SrtpPacketUtils.getSequenceNumber(pkt);
        long index = (long)this.guessedROC << 16 | (long)seqNo;
        this.ivStore[0] = this.saltKey[0];
        this.ivStore[1] = this.saltKey[1];
        this.ivStore[2] = this.saltKey[2];
        this.ivStore[3] = this.saltKey[3];
        for (i = 4; i < 8; ++i) {
            this.ivStore[i] = (byte)(0xFF & ssrc2 >> (7 - i) * 8 ^ this.saltKey[i]);
        }
        for (i = 8; i < 14; ++i) {
            this.ivStore[i] = (byte)(0xFF & (byte)(index >> (13 - i) * 8) ^ this.saltKey[i]);
        }
        this.ivStore[15] = 0;
        this.ivStore[14] = 0;
        int rtpHeaderLength = SrtpPacketUtils.getTotalHeaderLength(pkt);
        this.cipher.setIV(this.ivStore, 1);
        this.cipher.process(pkt.getBuffer(), pkt.getOffset() + rtpHeaderLength, pkt.getLength() - rtpHeaderLength);
    }

    private SrtpErrorStatus processPacketAesGcm(ByteArrayBuffer pkt, boolean encrypting, boolean skipDecryption) {
        int i;
        int ssrc2 = SrtpPacketUtils.getSsrc(pkt);
        int seqNo = SrtpPacketUtils.getSequenceNumber(pkt);
        long index = (long)this.guessedROC << 16 | (long)seqNo;
        this.ivStore[0] = this.saltKey[0];
        this.ivStore[1] = this.saltKey[1];
        for (i = 2; i < 6; ++i) {
            this.ivStore[i] = (byte)(0xFF & ssrc2 >> (5 - i) * 8 ^ this.saltKey[i]);
        }
        for (i = 6; i < 12; ++i) {
            this.ivStore[i] = (byte)(0xFF & (byte)(index >> (11 - i) * 8) ^ this.saltKey[i]);
        }
        int rtpHeaderLength = SrtpPacketUtils.getTotalHeaderLength(pkt);
        try {
            SrtpCipher cipher = skipDecryption ? this.cipherAuthOnly : this.cipher;
            cipher.setIV(this.ivStore, encrypting ? 1 : 2);
            cipher.processAAD(pkt.getBuffer(), pkt.getOffset(), rtpHeaderLength);
            int processLen = cipher.process(pkt.getBuffer(), pkt.getOffset() + rtpHeaderLength, pkt.getLength() - rtpHeaderLength);
            pkt.setLength(processLen + rtpHeaderLength);
        }
        catch (GeneralSecurityException e) {
            if (encrypting) {
                this.logger.info(() -> "Error encrypting SRTP packet: " + e.getMessage());
                return SrtpErrorStatus.FAIL;
            }
            if (e instanceof AEADBadTagException) {
                return SrtpErrorStatus.AUTH_FAIL;
            }
            this.logger.info(() -> "Error decrypting SRTP packet: " + e.getMessage());
            return SrtpErrorStatus.FAIL;
        }
        return SrtpErrorStatus.OK;
    }

    private void processPacketAesF8(ByteArrayBuffer pkt) throws GeneralSecurityException {
        System.arraycopy(pkt.getBuffer(), pkt.getOffset(), this.ivStore, 0, 12);
        this.ivStore[0] = 0;
        int roc = this.guessedROC;
        this.ivStore[12] = (byte)(roc >> 24);
        this.ivStore[13] = (byte)(roc >> 16);
        this.ivStore[14] = (byte)(roc >> 8);
        this.ivStore[15] = (byte)roc;
        int rtpHeaderLength = SrtpPacketUtils.getTotalHeaderLength(pkt);
        this.cipher.setIV(this.ivStore, 1);
        this.cipher.process(pkt.getBuffer(), pkt.getOffset() + rtpHeaderLength, pkt.getLength() - rtpHeaderLength);
    }

    public synchronized SrtpErrorStatus reverseTransformPacket(ByteArrayBuffer pkt, boolean skipDecryption) throws GeneralSecurityException {
        SrtpErrorStatus ret;
        SrtpErrorStatus err;
        if (this.sender) {
            throw new IllegalStateException("reverseTransformPacket called on SRTP sender");
        }
        if (!SrtpPacketUtils.validatePacketLength(pkt, this.policy.getAuthTagLength())) {
            return SrtpErrorStatus.INVALID_PACKET;
        }
        int seqNo = SrtpPacketUtils.getSequenceNumber(pkt);
        this.logger.debug(() -> "Reverse transform for SSRC " + this.ssrc + " SeqNo=" + seqNo + " s_l=" + this.s_l + " seqNumSet=" + this.seqNumSet + " guessedROC=" + this.guessedROC + " roc=" + this.roc);
        boolean seqNumWasJustSet = false;
        if (!this.seqNumSet) {
            this.seqNumSet = true;
            this.s_l = seqNo;
            seqNumWasJustSet = true;
        }
        long guessedIndex = this.guessIndex(seqNo);
        if (this.policy.isReceiveReplayDisabled() || (err = this.checkReplay(seqNo, guessedIndex)) == SrtpErrorStatus.OK) {
            err = this.authenticatePacket(pkt);
            if (err == SrtpErrorStatus.OK) {
                if (!skipDecryption || this.policy.getEncType() == 5) {
                    switch (this.policy.getEncType()) {
                        case 1: 
                        case 3: {
                            this.processPacketAesCm(pkt);
                            break;
                        }
                        case 5: {
                            err = this.processPacketAesGcm(pkt, false, skipDecryption);
                            break;
                        }
                        case 2: 
                        case 4: {
                            this.processPacketAesF8(pkt);
                        }
                    }
                }
                if (err == SrtpErrorStatus.OK) {
                    this.update(seqNo, guessedIndex);
                } else {
                    this.logger.debug(() -> "SRTP auth failed for SSRC " + this.ssrc);
                }
                ret = err;
            } else {
                this.logger.debug(() -> "SRTP auth failed for SSRC " + this.ssrc);
                ret = err;
            }
        } else {
            ret = err;
        }
        if (ret != SrtpErrorStatus.OK && seqNumWasJustSet) {
            this.seqNumSet = false;
            this.s_l = 0;
        }
        return ret;
    }

    public synchronized SrtpErrorStatus transformPacket(ByteArrayBuffer pkt) throws GeneralSecurityException {
        SrtpErrorStatus err;
        if (!this.sender) {
            throw new IllegalStateException("transformPacket called on SRTP receiver");
        }
        int seqNo = SrtpPacketUtils.getSequenceNumber(pkt);
        if (!this.seqNumSet) {
            this.seqNumSet = true;
            this.s_l = seqNo;
        }
        long guessedIndex = this.guessIndex(seqNo);
        if (this.policy.isSendReplayEnabled() && (err = this.checkReplay(seqNo, guessedIndex)) != SrtpErrorStatus.OK) {
            return err;
        }
        switch (this.policy.getEncType()) {
            case 1: 
            case 3: {
                this.processPacketAesCm(pkt);
                break;
            }
            case 5: {
                this.processPacketAesGcm(pkt, true, false);
                break;
            }
            case 2: 
            case 4: {
                this.processPacketAesF8(pkt);
            }
        }
        if (this.policy.getAuthType() != 0) {
            byte[] tagStore = this.authenticatePacketHmac(pkt, this.guessedROC);
            pkt.append(tagStore, this.policy.getAuthTagLength());
        }
        this.update(seqNo, guessedIndex);
        return SrtpErrorStatus.OK;
    }

    private void logReplayWindow(long newIdx) {
        this.logger.debug(() -> "Updated replay window with " + newIdx + ". " + SrtpPacketUtils.formatReplayWindow(this.roc << 16 | this.s_l, this.replayWindow, 64L));
    }

    private void update(int seqNo, long guessedIndex) {
        long delta = guessedIndex - ((long)this.roc << 16 | (long)this.s_l);
        if (delta >= 64L) {
            this.replayWindow = 1L;
        } else if (delta > 0L) {
            this.replayWindow <<= (int)delta;
            this.replayWindow |= 1L;
        } else {
            this.replayWindow |= 1L << (int)(-delta);
        }
        if (this.guessedROC == this.roc) {
            if (seqNo > this.s_l) {
                this.s_l = seqNo & 0xFFFF;
            }
        } else if (this.guessedROC == this.roc + 1) {
            this.s_l = seqNo & 0xFFFF;
            this.roc = this.guessedROC;
        }
        this.logReplayWindow(guessedIndex);
    }
}

