#![allow(non_camel_case_types)]
use crate::processor::elements::cheri::SafeTaggedCap;
use std::ops::Range;
use std::marker::PhantomData;
use crate::processor::isa_mods::*;
use crate::processor::exceptions::IllegalInstructionException::*;
use super::csrs::CSRProvider;
use std::cmp::min;
use anyhow::{bail, Context, Result};
use crate::processor::decode::{Opcode,InstructionBits};
mod types;
pub use types::*;
mod conns;
pub use conns::*;
mod decode;
pub use decode::*;
mod registers;
pub use registers::*;
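/// The vector unit: a vector register file plus the vtype/vl/vstart state
/// manipulated by the vset{i}vl{i} configuration instructions.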
pub struct Rvv<uXLEN: PossibleXlen, TElem> {
    /// The vector register file. TElem is the physical element type:
    /// u128 for plain RISC-V, SafeTaggedCap for CHERI.
    vreg: Box<dyn VectorRegisterFile<TElem>>,
    /// The currently configured vector type (SEW, LMUL, etc.).
    vtype: VType,
    /// The current vector length: how many elements the next vector instruction
    /// will operate on.
    vl: u32,
    /// The index of the first element the next vector instruction operates on.
    /// Usually zero; set nonzero when an instruction is interrupted partway through.
    vstart: u32,
    _phantom_xlen: PhantomData<uXLEN>,
}
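/// Concrete vector units: 32/64-bit base ISAs use plain u128 elements, while
/// 64-bit CHERI uses tagged-capability elements.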
pub type Rv32v = Rvv<u32, u128>;
pub type Rv64v = Rvv<u64, u128>;
pub type Rv64Cheriv = Rvv<u64, SafeTaggedCap>;
impl<uXLEN: PossibleXlen, TElem> Rvv<uXLEN, TElem> {
    /// Create a vector unit in the reset state, backed by the given register file.
    pub fn new(vreg: Box<dyn VectorRegisterFile<TElem>>) -> Self {
        Rvv {
            vreg,
            vtype: VType::illegal(),
            vl: 0,
            vstart: 0,
            _phantom_xlen: PhantomData,
        }
    }

    /// Reset all vector state: the register file and the vtype/vl/vstart values.
    pub fn reset(&mut self) {
        self.vreg.reset();
        self.vtype = VType::illegal();
        self.vl = 0;
        self.vstart = 0;
    }

    /// Execute a vset{i}vl{i} configuration instruction, updating vtype and vl.
    ///
    /// The application vector length (AVL) and the requested vtype are gathered
    /// from registers/immediates as selected by `inst_kind`; vl then becomes
    /// min(AVL, VLMAX) and is written back to rd.
    fn exec_config(&mut self, inst_kind: ConfigKind, inst: InstructionBits, sreg: &mut dyn VecRegInterface<uXLEN>) -> Result<()> {
        if let InstructionBits::VType{rd, funct3, rs1, rs2, zimm11, zimm10, ..} = inst {
            // All vset{i}vl{i} instructions use funct3 = 0b111.
            assert_eq!(funct3, 0b111);

            // Gather the application vector length (AVL): how many elements the
            // program wants to process.
            let avl = match inst_kind {
                ConfigKind::vsetvli | ConfigKind::vsetvl => {
                    if rs1 != 0 {
                        // rs1 != x0 => AVL is the value in rs1.
                        sreg.sreg_read_xlen(rs1)?.into()
                    } else {
                        if rd != 0 {
                            // rs1 == x0, rd != x0 => AVL = ~0, i.e. request VLMAX.
                            u64::MAX
                        } else {
                            // rs1 == x0, rd == x0 => keep the current vl (vtype may still change).
                            self.vl as u64
                        }
                    }
                },
                ConfigKind::vsetivli => {
                    // vsetivli encodes AVL as a 5-bit immediate in the rs1 field.
                    rs1 as u64
                }
            };

            // Gather the requested vtype bits from the immediate or rs2, depending
            // on the variant.
            let vtype_bits = match inst_kind {
                ConfigKind::vsetvli => {
                    zimm11 as u64
                },
                ConfigKind::vsetivli => {
                    zimm10 as u64
                },
                ConfigKind::vsetvl => {
                    sreg.sreg_read_xlen(rs2)?.into()
                },
            };
            
            let req_vtype = VType::decode(vtype_bits as u32)?;

            // elems_per_group is VLMAX (VLEN * LMUL / SEW); zero means the requested
            // SEW/LMUL combination is valid but unsupported by this emulator.
            let elems_per_group = req_vtype.elems_per_group();
            let vtype_supported = elems_per_group > 0;
            if vtype_supported {
                self.vtype = req_vtype;
                // vl = min(AVL, VLMAX); compare as u64 so a large AVL isn't truncated.
                self.vl = min(elems_per_group as u64, avl) as u32;
                sreg.sreg_write_xlen(rd, self.vl.into())?;
            } else {
                self.vtype = VType::illegal();
                // The spec would set vtype.vill here; this emulator bails instead
                // so unsupported configurations fail loudly.
                bail!("Valid but unsupported vtype: {:b} -> {:?}, elems_per_group {}", vtype_bits, req_vtype, elems_per_group);
            }
            Ok(())
        } else {
            unreachable!("vector::exec_config instruction MUST be InstructionBits::VType, got {:?} instead", inst);
        }
    }

    /// Find the range of segment indices in [vstart, evl) that are actually
    /// active, i.e. not masked out. Returns None if every segment is masked out.
    fn get_active_segment_range(&mut self, vm: bool, evl: u32) -> Option<Range<u32>> {
        // Find the first active segment index.
        let start = (self.vstart..evl)
            .filter(|&i| !self.vreg.seg_masked_out(vm, i))
            .min();
        // Find the last active segment index.
        let final_accessed = (self.vstart..evl)
            .filter(|&i| !self.vreg.seg_masked_out(vm, i))
            .max();

        // Both are Some iff at least one segment is active.
        match (start, final_accessed) {
            (Some(start), Some(final_accessed)) => Some(Range::<u32> {
                start,
                end: final_accessed + 1 // Range.end is exclusive
            }),
            _ => None
        }
    }

    /// Fast-path provenance check for a vector load/store: compute the full range
    /// of addresses the access will touch and check that range against the
    /// pointer's provenance in one go, instead of checking each element.
    ///
    /// Returns (check result, computed address range):
    /// - Ok(true): the whole range passed, so per-element accesses can't fault
    /// - Ok(false): a fault-only-first access failed and must be checked per element
    /// - Err(_): the access is invalid
    fn fast_check_load_store(&mut self, addr_provenance: (u64, Provenance), rs2: u8, vm: bool, op: DecodedMemOp, sreg: &mut dyn VecRegInterface<uXLEN>) -> (Result<bool>, Range<u64>) {
        let (base_addr, provenance) = addr_provenance;
        use DecodedMemOp::*;
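        // Fault-only-first accesses get relaxed checking; track whether this op is one.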
        let mut is_fault_only_first = false;
        // Compute the range of addresses accessed, based on the addressing mode.
        let addr_range = match op {
            Strided{stride, eew, evl, nf, ..} => {
                // Only active (unmasked) segments contribute to the accessed range.
                // unwrap() assumes at least one segment is active; an access with
                // every segment masked out would panic here.
                let Range{ start: active_vstart, end: active_evl } = self.get_active_segment_range(vm, evl).unwrap();

                let offset_range = Range::<u64> {
                    start: active_vstart as u64 * stride,
                    // The last segment starts at (active_evl - 1) * stride and
                    // covers nf fields of eew bytes each.
                    end: (active_evl as u64 - 1) * stride + (nf as u64) * eew.width_in_bytes()
                };
                Range::<u64> {
                    start: base_addr + offset_range.start,
                    end: base_addr + offset_range.end,
                }
            },
            },
            FaultOnlyFirst{evl, nf, eew, ..} => {
                is_fault_only_first = true;
                // Fault-only-first is unit-stride: segments are packed nf * eew bytes apart.
                let Range{ start: active_vstart, end: active_evl } = self.get_active_segment_range(vm, evl).unwrap();
                let offset_range = Range::<u64> {
                    start: (active_vstart as u64) * (nf as u64) * eew.width_in_bytes(),
                    // The exclusive end is the start of the segment just past the
                    // last active one.
                    end:   (active_evl as u64)    * (nf as u64) * eew.width_in_bytes()
                };
                Range::<u64> {
                    start: base_addr + offset_range.start,
                    end: base_addr + offset_range.end,
                }
            },
            Indexed{evl, nf, eew, index_ew, ..} => {
                // Gather every active segment's offset from the index vector rs2,
                // then bound the range with the min/max offsets.
                let Range{ start: active_vstart, end: active_evl } = self.get_active_segment_range(vm, evl).unwrap();
                let mut offsets = vec![];
                for i_segment in active_vstart..active_evl {
                    offsets.push(self.vreg.load_vreg_elem_int(index_ew, rs2, i_segment).unwrap());
                }
                let offset_range = Range::<u64> {
                    start: *offsets.iter().min().unwrap() as u64,
                    end: *offsets.iter().max().unwrap() as u64 + (nf as u64 * eew.width_in_bytes()),
                };
                Range::<u64> {
                    start: base_addr + offset_range.start,
                    end: base_addr + offset_range.end,
                }
            }
            WholeRegister{eew, ..} => {
                // Whole-register accesses ignore masking and vstart: they always
                // touch all op.evl() elements starting from element 0.
                let index_range = Range::<u64> {
                    start: 0,
                    end: (op.evl() as u64)
                };
                Range::<u64> {
                    start: base_addr + index_range.start * eew.width_in_bytes(),
                    end: base_addr + index_range.end * eew.width_in_bytes(),
                }
            }
            ByteMask{evl, ..} => {
                // Mask loads/stores are unit-stride byte accesses covering
                // elements vstart..evl.
                let index_range = Range::<u64> {
                    start: self.vstart as u64,
                    end: (evl as u64)
                };
                Range::<u64> {
                    start: base_addr + index_range.start,
                    end: base_addr + index_range.end,
                }
            }
        };
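        // One provenance check over the whole range stands in for per-element checks.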
        let check_result = sreg.check_addr_range_against_provenance(addr_range.clone(), provenance, op.dir());
        match check_result {
            Ok(()) => {
                // The whole range passed: element accesses can't fault.
                return (Ok(true), addr_range);
            }
            Err(e) => {
                // Fault-only-first accesses are allowed to fault partway through,
                // so fall back to per-element checking instead of failing outright.
                if is_fault_only_first {
                    return (Ok(false), addr_range);
                }
                // Any other access pattern must be valid over the whole range,
                // so propagate the error.
                return (Err(e), addr_range);
            }
        }
    }

    /// Expand a vector load/store into the full list of (vector element, address)
    /// pairs it will access, taking vstart, masking, and segment count (nf) into account.
    fn get_load_store_accesses(&mut self, rd: u8, addr_p: (u64, Provenance), rs2: u8, vm: bool, op: DecodedMemOp) -> Result<Vec<(VectorElem, u64)>> {
        let mut map = vec![];
        let (base_addr, _) = addr_p;
        use DecodedMemOp::*;
        match op {
            Strided{stride, evl, nf, eew, emul, ..} => {
                // Segment i_segment starts at base + i_segment * stride.
                for i_segment in self.vstart..evl {
                    let seg_addr = base_addr + (i_segment as u64 * stride);
                    // Skip segments that are masked out.
                    if !self.vreg.seg_masked_out(vm, i_segment) {
                        // Fields within a segment are contiguous in memory, but
                        // field i_field lands in the register group
                        // rd + i_field * (registers per EMUL group).
                        let mut field_addr = seg_addr;
                        for i_field in 0..nf {
                            let vec_elem = VectorElem::check_with_lmul(
                                rd + (i_field * emul.num_registers_consumed()),
                                eew, emul,
                                i_segment
                            );
                            map.push((vec_elem, field_addr));
                            field_addr += eew.width_in_bytes();
                        }
                    }
                }
            }
            FaultOnlyFirst{evl, nf, eew, emul} => {
                // Fault-only-first accesses are unit-stride: the effective stride
                // is the byte width of one whole segment (nf fields of eew bytes).
                let stride = eew.width_in_bytes() * (nf as u64);

                for i_segment in self.vstart..evl {
                    let seg_addr = base_addr + (i_segment as u64 * stride);
                    // Skip segments that are masked out.
                    if !self.vreg.seg_masked_out(vm, i_segment) {
                        let mut field_addr = seg_addr;
                        for i_field in 0..nf {
                            let vec_elem = VectorElem::check_with_lmul(
                                rd + (i_field * emul.num_registers_consumed()),
                                eew, emul,
                                i_segment
                            );
                            map.push((vec_elem, field_addr));
                            field_addr += eew.width_in_bytes();
                        }
                    }
                }
            }
            Indexed{index_ew, evl, nf, eew, emul, ..} => {
                // Each segment's offset comes from the corresponding element of
                // the index vector rs2.
                for i_segment in self.vstart..evl {
                    let seg_offset = self.vreg.load_vreg_elem_int(index_ew, rs2, i_segment)?;
                    let seg_addr = base_addr + seg_offset as u64;
                    // Skip segments that are masked out.
                    if !self.vreg.seg_masked_out(vm, i_segment) {
                        let mut field_addr = seg_addr;
                        for i_field in 0..nf {
                            let vec_elem = VectorElem::check_with_lmul(
                                rd + (i_field * emul.num_registers_consumed()),
                                eew, emul,
                                i_segment
                            );
                            map.push((vec_elem, field_addr));
                            field_addr += eew.width_in_bytes();
                        }
                    }
                }
            }
            WholeRegister{num_regs, eew, ..} => {
                if !vm {
                    // vm == 0 means the access is masked, which whole-register
                    // moves don't support.
                    bail!("WholeRegister operations cannot be masked")
                }
                let mut addr = base_addr;
                let vl = op.evl();
                for i in 0..vl {
                    let vec_elem = VectorElem::check_with_num_regs(rd, eew, num_regs, i);
                    map.push((vec_elem, addr));
                    addr += eew.width_in_bytes();
                }
            }
            ByteMask{evl, ..} => {
                if !vm {
                    // Mask loads/stores are themselves unmaskable.
                    bail!("ByteMask operations cannot be masked")
                }
                let mut addr = base_addr;
                for i in self.vstart..evl {
                    let vec_elem = VectorElem::check_with_lmul(
                        rd,
                        Sew::e8, Lmul::e1,
                        i
                    );
                    map.push((vec_elem, addr));
                    addr += 1;
                }
            }
        };
        Ok(map)
    }

    /// Execute a vector load/store, expanding it into per-element accesses and
    /// sanity-checking them against the address range computed by the fast check.
    fn exec_load_store(&mut self, expected_addr_range: Range<u64>, rd: u8, rs1: u8, rs2: u8, vm: bool, op: DecodedMemOp, sreg: &mut dyn VecRegInterface<uXLEN>, mem: &mut dyn VecMemInterface<uXLEN, TElem>) -> Result<()> {
        let addr_p = sreg.get_addr_provenance(rs1)?;
        let accesses = self.get_load_store_accesses(rd, addr_p, rs2, vm, op)?;
        let (_, provenance) = addr_p;

        // Sanity check: the range the fast path checked must exactly cover the
        // min/max addresses the element accesses will actually touch.
        let min_addr = accesses.iter().map(|(_, addr)| *addr).min().unwrap();
        // max_addr is exclusive: one past the last byte of the highest access.
        let max_addr = accesses.iter().map(|(elem, addr)| addr + elem.eew.width_in_bytes()).max().unwrap();
        if expected_addr_range.start != min_addr || expected_addr_range.end != max_addr {
            bail!("Computed fast-path address range 0x{:x}..0x{:x} doesn't match the min/max accessed addresses 0x{:x}..0x{:x}",
                expected_addr_range.start, expected_addr_range.end,
                min_addr, max_addr
            );
        }
        use DecodedMemOp::*;
        match op {
            Strided{dir, ..} | Indexed{dir, ..} | WholeRegister{dir, ..} | ByteMask{dir, ..} => {
                // Plain accesses: perform every element access; any fault is fatal.
                for (VectorElem{ base_reg, eew, elem_within_group, ..}, addr) in accesses {
                    let addr_p = (addr, provenance);
                    match dir {
                        MemOpDir::Load => self.load_to_vreg(mem, eew, addr_p, base_reg, elem_within_group)
                            .with_context(|| format!("Failure on element {}", elem_within_group))?,
                        MemOpDir::Store => self.store_to_mem(mem, eew, addr_p, base_reg, elem_within_group)
                            .with_context(|| format!("Failure on element {}", elem_within_group))?
                    }
                }
            }
            FaultOnlyFirst{..} => {
                // Fault-only-first: a fault on the first element traps as normal,
                // but a fault on any later element shrinks vl instead of trapping.
                for (VectorElem{ base_reg, eew, elem_within_group, ..}, addr) in accesses {
                    let addr_p = (addr, provenance);
                    let load_fault: Result<()> =
                        self.load_to_vreg(mem, eew, addr_p, base_reg, elem_within_group);

                    if elem_within_group == 0 {
                        // Faults on element 0 propagate as normal.
                        load_fault?;
                    } else if load_fault.is_err() {
                        use crate::processor::exceptions::{MemoryException, CapabilityException};

                        let load_err = load_fault.unwrap_err();
                        // Address-translation and capability faults shrink vl;
                        // any other error still propagates.
                        let error_reduces_vlen =
                            matches!(load_err.downcast_ref::<MemoryException>(), Some(MemoryException::AddressUnmapped{..}))
                            || load_err.downcast_ref::<CapabilityException>().is_some();
                        if error_reduces_vlen {
                            // Shrink vl to the number of elements that completed,
                            // so software can see how far the load got.
                            self.vl = elem_within_group;
                            break;
                        } else {
                            // Not a vl-reducing fault: re-raise it.
                            return Err(load_err)
                        }
                    }
                }
            }
        };
        Ok(())
    }

    /// Load one element of width `eew` from memory into element `idx_from_base`
    /// of the vector register group starting at `vd_base`.
    fn load_to_vreg(&mut self, mem: &mut dyn VecMemInterface<uXLEN, TElem>, eew: Sew, addr_provenance: (u64, Provenance), vd_base: u8, idx_from_base: u32) -> Result<()> {
        let val = mem.load_from_memory(eew, addr_provenance)?;
        self.vreg.store_vreg_elem(eew, vd_base, idx_from_base, val)?;
        Ok(())
    }

    /// Store one element of width `eew` from the vector register group starting
    /// at `vd_base` out to memory.
    fn store_to_mem(&mut self, mem: &mut dyn VecMemInterface<uXLEN, TElem>, eew: Sew, addr_provenance: (u64, Provenance), vd_base: u8, idx_from_base: u32) -> Result<()> {
        let val = self.vreg.load_vreg_elem(eew, vd_base, idx_from_base)?;
        mem.store_to_memory(eew, val, addr_provenance)?;
        Ok(())
    }

    /// Dump the vector unit state for debugging.
    pub fn dump(&self) {
        self.vreg.dump();
        println!("vl: {}\nvtype: {:?}", self.vl, self.vtype);
    }
}
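/// The connections the vector unit needs to the rest of the processor:
/// scalar-register access plus memory access.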
pub type VecInterface<'a, uXLEN, TElem> = (
    &'a mut dyn VecRegInterface<uXLEN>,
    &'a mut dyn VecMemInterface<uXLEN, TElem>
);
impl<uXLEN: PossibleXlen, TElem> IsaMod<VecInterface<'_, uXLEN, TElem>> for Rvv<uXLEN, TElem> {
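    // Vector instructions never redirect control flow, so there's no meaningful PC type.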
    type Pc = ();
    fn will_handle(&self, opcode: Opcode, inst: InstructionBits) -> bool {
        use crate::processor::decode::Opcode::*;
        match (opcode, inst) {
            // Every major-opcode-Vector instruction belongs to this unit.
            (Vector, _) => true,
            (LoadFP | StoreFP, InstructionBits::FLdStType{width, ..}) => {
                // The width field distinguishes scalar FP accesses from vector ones.
                match width {
                    // Scalar FP loads/stores (FLH/FLW/FLD/FLQ)
                    0b0001 | 0b0010 | 0b0011 | 0b0100 => false,
                    // Reserved encodings
                    0b1000..=0b1111 => false,
                    // Everything else is a vector access
                    _ => true
                }
            },
            _ => false
        }
    }

    /// Decode and execute a single vector instruction.
    ///
    /// Configuration instructions (funct3 = 0b111) update vtype/vl, a subset of
    /// OPIVV/OPIVI arithmetic is handled inline, and loads/stores are fast-checked
    /// then expanded into per-element accesses.
    ///
    /// Returns Ok(None) on success: vector instructions never change the PC.
    fn execute(&mut self, opcode: Opcode, inst: InstructionBits, inst_bits: u32, conn: VecInterface<'_, uXLEN, TElem>) -> ProcessorResult<Option<()>> {
        let (sreg, mem) = conn;
        use Opcode::*;
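        // Dispatch on (major opcode, decoded instruction shape).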
        match (opcode, inst) {
            (Vector, InstructionBits::VType{funct3, funct6, rs1, rs2, rd, vm, ..}) => {
                match funct3 {
                    0b111 => {
                        // vset{i}vl{i}: the top two bits of the instruction select the variant.
                        let inst_kind = match bits!(inst_bits, 30:31) {
                            0b00 | 0b01 => ConfigKind::vsetvli,
                            0b11 => ConfigKind::vsetivli,
                            0b10 => ConfigKind::vsetvl,
                            invalid => panic!("impossible top 2 bits {:02b}", invalid)
                        };
                        self.exec_config(inst_kind, inst, sreg)?
                    }
                    0b000 => {
                        // OPIVV: integer vector-vector ops; rs1/rd name vector registers.
                        let vs1 = rs1;
                        let vd = rd;
                        match funct6 {
                            0b010111 => {
                                // vmv.v.v: copy vs1 into vd, element by element.
                                if !vm {
                                    bail!("vector-vector move can't be masked");
                                }
                                for i in self.vstart..self.vl {
                                    let val = self.vreg.load_vreg_elem(self.vtype.vsew, vs1, i)?;
                                    self.vreg.store_vreg_elem(self.vtype.vsew, vd, i, val)?;
                                }
                            }
                            _ => bail!("Unsupported OPIVV funct6 {:06b}", funct6)
                        }
                    }
                    0b011 => {
                        // OPIVI: integer ops with a 5-bit immediate in the rs1 field.
                        // Note the immediate is zero-extended here, whereas the spec
                        // sign-extends simm5 for most OPIVI instructions.
                        let imm = rs1 as u128;
                        match funct6 {
                            0b011000 => {
                                // vmseq.vi: set mask bit i of vd where element i of vs2 == imm.
                                let mut val: uVLEN = 0;
                                for i in self.vstart..self.vl {
                                    let reg_val = self.vreg.load_vreg_elem_int(self.vtype.vsew, rs2, i)?;
                                    if reg_val == imm {
                                        val |= (1 as uVLEN) << i;
                                    }
                                }
                                self.vreg.store_vreg_int(rd, val)?;
                            }
                            0b011001 => {
                                // vmsne.vi: set mask bit i of vd where element i of vs2 != imm.
                                let mut val: uVLEN = 0;
                                for i in self.vstart..self.vl {
                                    if self.vreg.load_vreg_elem_int(self.vtype.vsew, rs2, i)? != imm {
                                        val |= (1 as uVLEN) << i;
                                    }
                                }
                                self.vreg.store_vreg_int(rd, val)?;
                            }
                            0b010111 => {
                                if (!vm) && rd == 0 {
                                    bail!(UnsupportedParam("Can't handle vmerge on the mask register, because it uses the mask register :)".to_string()));
                                }
                                // vmerge.vim (masked) / vmv.v.i (unmasked): take the
                                // immediate where the mask bit is set, else keep vs2's element.
                                for i in self.vstart..self.vl {
                                    let val = if self.vreg.seg_masked_out(vm, i) {
                                        // Mask bit clear => keep the old value from vs2.
                                        self.vreg.load_vreg_elem_int(self.vtype.vsew, rs2, i)?
                                    } else {
                                        // Mask bit set (or the op is unmasked) => take the immediate.
                                        imm
                                    };
                                    self.vreg.store_vreg_elem_int(self.vtype.vsew, rd, i, val)?;
                                }
                            }
                            0b100111 => {
                                if vm {
                                    // vmv<nr>r.v: whole-register move of nr registers.
                                    // nr - 1 is encoded in the rs1/imm field and must
                                    // give nr = 1, 2, 4, or 8.
                                    let nr = rs1 + 1;
                                    let emul = match nr {
                                        1 => Lmul::e1,
                                        2 => Lmul::e2,
                                        4 => Lmul::e4,
                                        8 => Lmul::e8,
                                        _ => bail!(UnsupportedParam(format!("Invalid nr encoding in vmv<nr>r.v: nr = {}", nr)))
                                    };
                                    let eew = self.vtype.vsew;
                                    let evl = val_times_lmul_over_sew(VLEN as u32, eew, emul);
                                    if self.vstart >= evl {
                                        bail!(UnsupportedParam(format!("evl {} <= vstart {} therefore vector move is no op", evl, self.vstart)))
                                    }
                                    if rd == rs2 {
                                        // Source == destination: nothing to do.
                                        return Ok(None)
                                    }
                                    for vx in 0..nr {
                                        let val = self.vreg.load_vreg(rs2 + vx)?;
                                        self.vreg.store_vreg(rd + vx, val)?;
                                    }
                                } else {
                                    // vm == 0 with this funct6 encodes vsmul, which isn't implemented.
                                    bail!(UnimplementedInstruction("vsmul"));
                                }
                            }
                            0b000000 => {
                                // vadd.vi: element-wise wrapping add of the immediate.
                                if (!vm) && rd == 0 {
                                    bail!(UnsupportedParam("Can't handle vadd on the mask register, because it uses the mask register :)".to_string()));
                                }
                                for i in self.vstart..self.vl {
                                    if !self.vreg.seg_masked_out(vm, i) {
                                        let val = self.vreg.load_vreg_elem_int(self.vtype.vsew, rs2, i)?;
                                        // Truncate to the element width, add, and wrap.
                                        let val = match self.vtype.vsew {
                                            Sew::e8 => {
                                                (val as u8).wrapping_add(imm as u8) as u128
                                            }
                                            Sew::e16 => {
                                                (val as u16).wrapping_add(imm as u16) as u128
                                            }
                                            Sew::e32 => {
                                                (val as u32).wrapping_add(imm as u32) as u128
                                            }
                                            Sew::e64 => {
                                                (val as u64).wrapping_add(imm as u64) as u128
                                            }
                                            Sew::e128 => {
                                                (val as u128).wrapping_add(imm as u128) as u128
                                            }
                                        };
                                        self.vreg.store_vreg_elem_int(self.vtype.vsew, rd, i, val)?;
                                    }
                                }
                            }
                            _ => bail!(MiscDecodeException(format!(
                                    "Vector arithmetic funct3 {:03b} funct6 {:06b} not yet handled", funct3, funct6)
                            ))
                        }
                    }
                    _ => bail!(UnsupportedParam(format!("Vector arithmetic funct3 {:03b} currently not supported", funct3)))
                }
            }
            (LoadFP | StoreFP, InstructionBits::FLdStType{rd, rs1, rs2, vm, ..}) => {
                // Decode the full shape of the access (pattern, eew, evl, nf, direction).
                let op = DecodedMemOp::decode_load_store(opcode, inst, self.vtype, self.vl, sreg)?;

                // Masked loads use v0 as the mask, so they can't also write v0.
                if op.dir() == MemOpDir::Load && (!vm) && rd == 0 {
                    bail!("Masked instruction cannot load into v0");
                }
                // If vstart >= evl there are no elements to access: a no-op per the spec.
                if op.evl() <= self.vstart {
                    println!("EVL {} <= vstart {} => vector {:?} is no-op", op.evl(), self.vstart, op.dir());
                    return Ok(None)
                }
                let addr_provenance = sreg.get_addr_provenance(rs1)?;

                // Try to validate the whole address range up front; the range is
                // also returned so exec_load_store can sanity-check itself.
                let (fast_check_result, addr_range) = self.fast_check_load_store(addr_provenance, rs2, vm, op, sreg);
                match fast_check_result {
                    // Whole range pre-checked: element accesses shouldn't fault.
                    Ok(true) => {
                        self.exec_load_store(addr_range, rd, rs1, rs2, vm, op, sreg, mem)
                            .context("Executing pre-checked vector access - shouldn't throw CapabilityExceptions under any circumstances")
                    },
                    // The access is definitely invalid: propagate the error.
                    Err(e) => {
                        Err(e)
                    }
                    // Fault-only-first that failed the fast check: execute with
                    // per-element checks, which may legally shrink vl.
                    Ok(false) => {
                        self.exec_load_store(addr_range, rd, rs1, rs2, vm, op, sreg, mem)
                            .context("Executing not-pre-checked vector access - may throw CapabilityException")
                    },
                }.context(format!("Executing vector access {:?}", op))?;
            }
            _ => bail!("Unexpected opcode/InstructionBits pair at vector unit")
        }

        // Any vector instruction that completes successfully resets vstart to zero.
        self.vstart = 0;
        Ok(None)
    }
}
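// The vector unit owns the V-extension CSRs: vl, vtype, and vlenb are implemented
// (read-only); vstart, vxsat, vxrm, and vcsr are not yet.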
impl<uXLEN: PossibleXlen, TElem> CSRProvider<uXLEN> for Rvv<uXLEN, TElem> {
    fn has_csr(&self, csr: u32) -> bool {
        match csr {
            // vstart (0x008), vxsat (0x009), vxrm (0x00A), vcsr (0x00F): not yet implemented
            0x008 | 0x009 | 0x00A | 0x00F => todo!(),
            // vl (0xC20), vtype (0xC21), vlenb (0xC22)
            0xC20 | 0xC21 | 0xC22 => true,
            _ => false
        }
    }
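
    // vl, vtype, and vlenb are read-only from software's perspective, so the
    // atomic write/set operations below reject them.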
    fn csr_atomic_read_write(&mut self, csr: u32, _need_read: bool, _write_val: uXLEN) -> Result<Option<uXLEN>> {
        match csr {
            0xC20 | 0xC21 | 0xC22 => bail!("CSR 0x{:04x} is read-only, cannot atomic read/write", csr),
            _ => todo!()
        }
    }
    fn csr_atomic_read_set(&mut self, csr: u32, set_bits: Option<uXLEN>) -> Result<uXLEN> {
        if set_bits.is_some() {
            match csr {
                0xC20 | 0xC21 | 0xC22 => bail!("CSR 0x{:04x} is read-only, cannot atomic set", csr),
                _ => todo!()
            }
        } else {
            match csr {
                0xC20 => Ok(self.vl.into()),
                0xC21 => Ok(self.vtype.encode().into()),
                // vlenb = VLEN in bytes
                0xC22 => Ok(((VLEN/8) as u32).into()),
                _ => todo!()
            }
        }
    }
    fn csr_atomic_read_clear(&mut self, _csr: u32, _clear_bits: Option<uXLEN>) -> Result<uXLEN> {
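        // No vector CSR currently supports atomic read-and-clear.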
        todo!()
    }
}