Da li je vreme da se uci aarch64 asembler?

Ucim instrukcije, vrlo je obimno, ima bar 100 do 400 instrukcija, pausalno. No hvala liku sto je dokumentovao bar.
Izabrao sam da uradim intrinsics u Rustu, kao jeziku koji najvise obecava od novih. Neki lik je zapoceo,
ali je uradio samo load/store i jednu jedinu instrukciju :heart:
I sad nema deljenja ni korenovanja, pa cu i to morati da uradim preko mnozenja i aproksimacije polinomom.
Likovi su to pravili za neuronske mreze, pa nema ovih operacija...
 
I evo primera prve instrukcije, (Rust).
fma64, najkorisnija instrukcija, racuna x*y+z i to ide po indeksu po redovima. Dakle, da bi se izmnozili svi redovi, potrebno
je pozvati instrukciju 8 puta. U primeru je pozivam za prvi i poslednji red.
Kod:
use amx::{prelude::*, XBytes, XRow, YBytes, YRow, ZRow};


/// Demo of the AMX `fma64_z` instruction.
///
/// Fills the X and Y register rows (8 rows each) and all 64 Z rows with f64 data,
/// prints X and Y, runs `fma64_z` for row indices 0 and 7, and prints Z.
/// Everything sits in one `unsafe` block because `std::mem::transmute` is unsafe
/// (and the `amx` load/read calls presumably are too — NOTE(review): confirm in the crate).
fn main() {
    unsafe {
        // `unwrap` panics if `AmxCtx::new()` fails.
        let mut ctx = amx::AmxCtx::new().unwrap();

        // u16 inputs from an earlier integer experiment; only referenced by the
        // commented-out loads below (currently unused).
        let in_x: Vec<u16> = vec![1;256];
        let in_y: Vec<u16> = vec![3;256];
        // f64 inputs: 64 values for X and Y (8 rows x 8 lanes), 512 for Z (64 rows x 8 lanes).
        let mut in_xf: Vec<f64> = vec![1.0;64];
        let mut in_yf: Vec<f64> = vec![3.0;64];
        let in_zf: Vec<f64> = vec![2.0;64*8];
        // The inner loop adds `i` eight times (`j` is unused), so in_xf[i] = 1 + 8*i
        // and in_yf[i] = 3 + 8*i, matching the X/Y dump shown after the code.
        for i in 0..64 {
          for j in 0..8{
            in_xf[i] += i as f64;
            in_yf[i] += i as f64;
          }
        }
        ctx.clear();
        ctx.set0();

        // Row i of X/Y receives the 8 consecutive values starting at index i*8.
        for i in 0..8 {
            //ctx.load512(&in_x[i * 32], XRow(i));
            //ctx.load512(&in_y[i * 32], YRow(i));
            ctx.load512(&in_xf[i*8], XRow(i));
            ctx.load512(&in_yf[i*8], YRow(i));
        }
        // Every Z row starts as eight 2.0 values.
        for i in 0..64 {
            ctx.load512(&in_zf[i*8], ZRow(i));
        }

//        println!("x = {:?}", *(in_x.as_ptr() as *const [[u16; 32]; 8]));
//        println!("y = {:?}", *(in_y.as_ptr() as *const [[u16; 32]; 8]));
       // Reinterpret the raw register dumps as 8x8 f64 matrices for printing.
       let got_x = std::mem::transmute::<_,[[f64;8];8]>(ctx.read_x());
       let got_y = std::mem::transmute::<_,[[f64;8];8]>(ctx.read_y());
       println!("X");
       printA::<8,8>(&got_x);
       println!("Y");
       printA::<8,8>(&got_y);
/*
            ctx.outer_product_u32_xy_to_z(
                Some(XBytes(x_offset)),
                Some(YBytes(y_offset)),
                ZRow(z_index),
                false, // don't accumulate
            );
            ctx.reduce_u32_to_z();
*/
            // Per the Z dump in the post, fma64_z(k) accumulates an outer product:
            // Z[k + 8*j] += X[k] * Y[k][j] for j in 0..8 (rows k, k+8, ..., k+56 change).
            ctx.fma64_z(0);
            ctx.fma64_z(7);
/*
*/
//            let got_z = std::mem::transmute::<_,[[u32;16];64]>(ctx.read_z());
            // All 64 Z rows viewed as f64 x 8.
            let got_z = std::mem::transmute::<_,[[f64;8];64]>(ctx.read_z());
            println!("Z");
            printA::<64,8>(&got_z);

    }
}
/// Prints a `ROWS x COLS` matrix of `f64`, one row per line, using `Debug` formatting.
///
/// The camel-case name is kept because callers invoke it as `printA::<R, C>(..)`;
/// only the const parameters were renamed (they are passed positionally).
#[allow(non_snake_case)]
fn printA<const ROWS: usize, const COLS: usize>(a: &[[f64; COLS]; ROWS]) {
    for row in a {
        println!("{:?}", row);
    }
}

izlaz izgleda ovako:
Kod:
X
[1.0, 9.0, 17.0, 25.0, 33.0, 41.0, 49.0, 57.0]
[65.0, 73.0, 81.0, 89.0, 97.0, 105.0, 113.0, 121.0]
[129.0, 137.0, 145.0, 153.0, 161.0, 169.0, 177.0, 185.0]
[193.0, 201.0, 209.0, 217.0, 225.0, 233.0, 241.0, 249.0]
[257.0, 265.0, 273.0, 281.0, 289.0, 297.0, 305.0, 313.0]
[321.0, 329.0, 337.0, 345.0, 353.0, 361.0, 369.0, 377.0]
[385.0, 393.0, 401.0, 409.0, 417.0, 425.0, 433.0, 441.0]
[449.0, 457.0, 465.0, 473.0, 481.0, 489.0, 497.0, 505.0]
Y
[3.0, 11.0, 19.0, 27.0, 35.0, 43.0, 51.0, 59.0]
[67.0, 75.0, 83.0, 91.0, 99.0, 107.0, 115.0, 123.0]
[131.0, 139.0, 147.0, 155.0, 163.0, 171.0, 179.0, 187.0]
[195.0, 203.0, 211.0, 219.0, 227.0, 235.0, 243.0, 251.0]
[259.0, 267.0, 275.0, 283.0, 291.0, 299.0, 307.0, 315.0]
[323.0, 331.0, 339.0, 347.0, 355.0, 363.0, 371.0, 379.0]
[387.0, 395.0, 403.0, 411.0, 419.0, 427.0, 435.0, 443.0]
[451.0, 459.0, 467.0, 475.0, 483.0, 491.0, 499.0, 507.0]
Z
[5.0, 29.0, 53.0, 77.0, 101.0, 125.0, 149.0, 173.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[202501.0, 206109.0, 209717.0, 213325.0, 216933.0, 220541.0, 224149.0, 227757.0]
[13.0, 101.0, 189.0, 277.0, 365.0, 453.0, 541.0, 629.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[206093.0, 209765.0, 213437.0, 217109.0, 220781.0, 224453.0, 228125.0, 231797.0]
[21.0, 173.0, 325.0, 477.0, 629.0, 781.0, 933.0, 1085.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[209685.0, 213421.0, 217157.0, 220893.0, 224629.0, 228365.0, 232101.0, 235837.0]
[29.0, 245.0, 461.0, 677.0, 893.0, 1109.0, 1325.0, 1541.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[213277.0, 217077.0, 220877.0, 224677.0, 228477.0, 232277.0, 236077.0, 239877.0]
[37.0, 317.0, 597.0, 877.0, 1157.0, 1437.0, 1717.0, 1997.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[216869.0, 220733.0, 224597.0, 228461.0, 232325.0, 236189.0, 240053.0, 243917.0]
[45.0, 389.0, 733.0, 1077.0, 1421.0, 1765.0, 2109.0, 2453.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[220461.0, 224389.0, 228317.0, 232245.0, 236173.0, 240101.0, 244029.0, 247957.0]
[53.0, 461.0, 869.0, 1277.0, 1685.0, 2093.0, 2501.0, 2909.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[224053.0, 228045.0, 232037.0, 236029.0, 240021.0, 244013.0, 248005.0, 251997.0]
[61.0, 533.0, 1005.0, 1477.0, 1949.0, 2421.0, 2893.0, 3365.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[227645.0, 231701.0, 235757.0, 239813.0, 243869.0, 247925.0, 251981.0, 256037.0]
Dakle prvo ucitavamo registre, pa vrsimo operaciju za 0 i 7 indeks.
Ako nekog zanima, mogu lib da stavim na github, no mora da ima Apple M1 ili M2 procesor. :heart:
 
E sad, korenovanje preko mnozenja i sabiranja.
Gledao sam vikipediju i nasao mislim najoptimalnije resenje.
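Prvo se odredi gruba procena tako sto se broj faktorise u obliku $S = a \cdot 10^{2n}$, gde je $1 \le a < 100$.
Kada je $a < 10$, grub koren je $\sqrt{S} \approx (0.28a + 0.89) \cdot 10^n$,
a kada je $a \ge 10$, $\sqrt{S} \approx (0.089a + 2.8) \cdot 10^n$ (obe formule daju $3.69$ za $a = 10$, pa se lepo nastavljaju).
Onda se primeni Njutnov metod za $1/\sqrt{S}$, gde je $x_0$ reciprocna vrednost te procene:
$x_{n+1} = x_n \left(\tfrac{3}{2} - \tfrac{S}{2} x_n^2\right)$,
i tako idu iteracije do dovoljne preciznosti za f64.
Ovo daje $1/\sqrt{S}$, a sam koren se dobija kao $\sqrt{S} = S \cdot \tfrac{1}{\sqrt{S}}$.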
Samo faktorisanje broja ne moze paralelno, sve ostalo moze.
Cilj je paralelno izracunati vektor od 8 f64 broja.
1/broj se izracunava dosta jednostavnije, to vec imam.
 
i evo implementacije:
Kod:
/// Approximates `1/row[i]` for each of the 8 lanes, using the AMX unit for the refinement.
///
/// The initial guess subtracts the raw bit pattern of each input from the magic constant
/// `0x7FDE6238502484BA` (the f64 reciprocal analogue of the fast inverse-square-root trick).
/// That guess is then refined with AMX multiply-add steps.
/// Explicitly loads AMX rows Z1, X1, Y1 and Z2, so their previous contents are lost.
fn rcp(&mut self,row:&[f64;8])->[f64;8]{
      let mut rc = [0.0;8];
      let one = [1.0;8];
      let zero = [0.0;8];
      let mut magic:[u64;8] = [0x7FDE6238502484BA;8];
      // `wrapping_sub`: inputs whose bits exceed the magic constant (e.g. negative numbers,
      // sign bit set) overflowed plain `-=` and panicked in debug builds.
      for (v, x) in magic.iter_mut().zip(row.iter()) {
        *v = v.wrapping_sub(x.to_bits());
      }
      unsafe {
        self.load512(&one,ZRow(1));
        self.load512(&magic,XRow(1));
        self.load512(row,YRow(1));
      }
      self.fms64_vec(1,1,1);
      self.extr_y(1,1);
      // Three refinement rounds; the exact dataflow follows the amx crate's extr/fma semantics.
      for _ in 0..3 {
        self.fma64_vec_x(1,1);
        self.fma64_vec(1,1,1);
        self.extr_x(1,1);
        self.extr_xy(0,1);
        unsafe {self.load512(&zero,ZRow(2));}
        self.fma64_vec(2,0,1);
        self.extr_y(2,1);
      }
      self.extr_xy(1,1);
      self.extr_y(1,1);
      self.fma64_vec(1,1,1);
      unsafe { self.store512(&mut rc,ZRow(1));}
      rc
    }
    fn sqrt(& mut self, s:&[f64;8])->[f64;8] {
      let mut rc = [0.0;8];
      let mut a = *s;
      let mut sqr10 = [1.0;8];
      let mut in_x = [0.0;8];
      let mut in_z = [0.0;8];
      let zero = [0.0;8];
      let zero_point_five = [0.5;8];
      unsafe {
          self.load512(s,XRow(1));
          self.load512(&zero_point_five,YRow(1));
          self.load512(&zero,ZRow(1));
      }
      self.fma64_vec(1,1,1);
      self.extr_x(1,1);
      for (i,mut a) in a.iter_mut().enumerate() {
        while *a > 100.0 {
          *a *= 0.001;
          sqr10[i]*= 10.0;
        }
        if *a < 1.0 { *a = 1.0 }
      }
      for (i,a) in a.iter().enumerate(){
        (in_x[i],in_z[i]) = if *a < 10.0 {
          (0.28, 0.89)
        } else {
          (0.89, 2.8)
        }
      }
      unsafe {
      self.load512(&in_x,XRow(0));
      self.load512(&a,YRow(0));
      self.load512(&in_z,ZRow(0));
      }
      self.fma64_vec(0,0,0);
      self.extr_x(0,0);
      unsafe {
      self.load512(&sqr10,YRow(0));
      self.load512(&zero,ZRow(0));
      }
      self.fma64_vec(0,0,0); // we have estimate
      unsafe {
        self.load512(&zero_point_five,XRow(7));
        self.store512(&mut a,ZRow(0));;
      }
      for i in 0..6 {
        let rcp = self.rcp(&a);
        unsafe{
          self.load512(&a,ZRow(0));
          self.load512(s,XRow(0));
          self.load512(&rcp,YRow(0));
        }
        self.fma64_vec(0,0,0);

        self.extr_y(0,0);
        unsafe { self.load512(&zero,ZRow(0));}
        self.fma64_vec(0,7,0);
        unsafe { self.store512(&mut a,ZRow(0));}
      }
      unsafe {self.store512(&mut rc,ZRow(0));}
      rc
    }
ovako se koristi:
Kod:
            // Usage/benchmark of the array-based `sqrt`/`rcp`.
            // Despite its name, `two` holds 65536.0 in every lane (root: 256.0).
            let two = [65536.0;8];
            // NOTE(review): the purpose of these two fma64_vec calls is not visible in this
            // excerpt; they run before the benchmark loop.
            ctx.fma64_vec(0,0,0);
            ctx.fma64_vec(7,7,7);
            // Timing loop: ten million calls, results discarded.
            for _ in 0..10000000 {
              ctx.sqrt(&two);
            }
            println!("sqrt\n{:?}",ctx.sqrt(&two));
            let rcp = [2.0;8];
            println!("rcp\n{:?}",ctx.rcp(&rcp));
 
Evo, ali zbog toga sto AMX nema u64 i cmp instrukcije, mora jos malo da ostane load/store. Skinuo sam jedno 50% vremena izvrsavanja,
Kod:
    /// Approximates `1/x` for the 8 f64 lanes of AMX Z row `zrow_in`, writing the result
    /// to Z row `zrow_out`.
    ///
    /// The input row is read back to the CPU once, because the magic-constant guess needs
    /// 64-bit integer subtraction. Explicitly loads Z1, X1, Y1 and X2 (scratch).
    fn rcp(&mut self,zrow_in:u64,zrow_out:u64){
      let one = [1.0;8];
      let zero = [0.0;8];
      let mut magic:[u64;8] = [0x7FDE6238502484BA;8];
      let mut row = [0u64;8];
      unsafe { self.store512(&mut row,ZRow(zrow_in as usize));}
      // `wrapping_sub`: bit patterns above the magic constant (e.g. negative inputs)
      // overflowed plain `-=` and panicked in debug builds.
      for (v, bits) in magic.iter_mut().zip(row.iter()) {
        *v = v.wrapping_sub(*bits);
      }
      unsafe {
        self.load512(&one,ZRow(1));
        self.load512(&magic,XRow(1));
        self.load512(&row,YRow(1));
        self.load512(&zero,XRow(2));
      }
      self.fms64_vec(1,1,1);
      self.extr_y(1,1);
      for _ in 0..3 {
        self.fma64_vec_x(1,1);
        self.fma64_vec(1,1,1);
        self.extr_x(1,1);
        self.extr_xy(0,1);
        // Zero Z2 by copying the all-zero X2 row into it.
        self.fma64_vec_x(2,2);
        self.fma64_vec(2,0,1);
        self.extr_y(2,1);
      }
      self.extr_xy(1,1);
      self.extr_y(1,1);
      self.fma64_vec(1,1,1);
      // Copy the result from Z1 into the requested output row via X0.
      self.extr_x(1,0);
      self.fma64_vec_x(zrow_out,0);
    }
    /// Computes `sqrt(s)` for the 8 f64 lanes of AMX Z row `zrow_in`, writing the result to
    /// Z row `zrow_out`.
    ///
    /// The algorithm has three steps:
    /// - Scalar range reduction `s = a * 100^n` (`1 <= a <= 100`).
    /// - A linear estimate `0.28a + 0.89` (`a < 10`) or `0.089a + 2.8`, scaled by `10^n`.
    /// - Six iterations of `x <- 0.5 * (x + s * rcp(x))`.
    ///
    /// Scratch rows: Z0, Z1, Z60..Z63 (Z61 = 0.5, Z62 = s, Z63 = 0), plus X0/X1/X6/X7, Y0/Y1
    /// and whatever `rcp` uses.
    fn sqrt(& mut self, zrow_in:u64,zrow_out:u64){
      let mut a = [0.0f64;8];
      unsafe {self.store512(&mut a,ZRow(zrow_in as usize));}
      let mut sqr10 = [1.0;8];
      let mut in_x = [0.0;8];
      let mut in_z = [0.0;8];
      let zero = [0.0;8];
      let zero_point_five = [0.5;8];
      unsafe {
          self.extr_x(zrow_in,1);
          self.load512(&zero_point_five,YRow(1));
          self.load512(&zero,ZRow(1));
          self.load512(&zero,ZRow(63));
          // Keep a copy of s in Z62 (X0 -> Z62).
          self.extr_x(zrow_in,0);
          self.fma64_vec_x(62,0);
          self.load512(&zero_point_five,ZRow(61));
      }
      // Computes 0.5*s into row 1; the result is not read later in this function.
      self.fma64_vec(1,1,1);
      self.extr_x(1,1);
      // Range reduction: each factor of 100 removed from `a` adds exactly one factor of 10
      // to the root (the old code removed 1000 per step but still scaled by 10).
      for (lane, a_lane) in a.iter_mut().enumerate() {
        if a_lane.is_finite() && *a_lane > 0.0 {
          while *a_lane > 100.0 {
            *a_lane *= 0.01;
            sqr10[lane] *= 10.0;
          }
          // Scale small inputs up too; clamping them to 1 prevented convergence.
          while *a_lane < 1.0 {
            *a_lane *= 100.0;
            sqr10[lane] *= 0.1;
          }
        } else if *a_lane < 1.0 {
          // Zero/negative/-inf keep the original clamp; +inf/NaN no longer loop forever.
          *a_lane = 1.0;
        }
      }
      // Piecewise-linear sqrt estimate on [1, 100]; both pieces give 3.69 at a = 10.
      for (lane, a_lane) in a.iter().enumerate() {
        (in_x[lane], in_z[lane]) = if *a_lane < 10.0 {
          (0.28, 0.89)
        } else {
          (0.089, 2.8)
        };
      }
      unsafe {
      self.load512(&in_x,XRow(0));
      self.load512(&a,YRow(0));
      self.load512(&in_z,ZRow(0));
      }
      self.fma64_vec(0,0,0);
      self.extr_x(0,0);
      unsafe {
      self.load512(&sqr10,YRow(0));
      self.load512(&zero,ZRow(0));
      }
      self.fma64_vec(0,0,0); // we have estimate
      // X7 = 0.5 for the halving step.
      self.extr_x(61,7);
      for _ in 0..6 {
        // Z60 = x_n, then Z60 = 1/x_n
        self.extr_x(0,6);
        self.fma64_vec_x(60,6);
        self.rcp(60,60);
        // Z0 = s * (1/x_n) + x_n
        self.extr_y(60,0);
        self.extr_x(62,0);
        self.fma64_vec(0,0,0);

        // Z0 = 0.5 * (x_n + s/x_n)
        self.extr_y(0,0);
        self.extr_x(63,0);
        self.fma64_vec_x(0,0);
        self.fma64_vec(0,7,0);
      }
      self.extr_x(0,0);
      self.fma64_vec_x(zrow_out,0);
    }
i sad je malo komplikovanije za koriscenje:
Kod:
            // Benchmark: AMX `sqrt` (register-row version) vs. scalar `f64::sqrt`,
            // one million rounds of 8 lanes each, input 2.0 in every lane.
            let mut two:[f64;8] = [2.0;8];
            let two1 = two;
            ctx.fma64_vec(0,0,0);
            ctx.fma64_vec(7,7,7);
            let start = Instant::now();
            let mut sum = 0.0;
            ctx.load512(&two,ZRow(50));
            for _ in 0..1000000 {
              ctx.sqrt(50,51);
              ctx.store512(&mut two,ZRow(51));
              sum+=two.iter().sum::<f64>();
            }
            let diff = start.elapsed().as_secs_f64();
            println!("simd time {} sum {}",diff, sum);
            let start = Instant::now();
            let mut sum = 0.0;
            for _ in 0..1000000 {
              // `black_box` hides the constant input from the optimizer; otherwise sqrt(2.0)
              // can be constant-folded and the scalar baseline times only the additions.
              for v in std::hint::black_box(two1) {
                sum+=v.sqrt();
              }
            }
            let diff = start.elapsed().as_secs_f64();
            println!("seq time {} sum {}",diff,sum);
            ctx.sqrt(50,51);
            ctx.store512(&mut two,ZRow(51));
            println!("sqrt\n{:?}",two);
            let mut rcp = [2.0;8];
            ctx.load512(&rcp,ZRow(63));
            ctx.rcp(63,63);
            ctx.store512(&mut rcp,ZRow(63));
            println!("rcp\n{:?}",rcp);
Praviti asembler od ovoga nema smisla nesto, preveliki posao, jer treba sve te instrukcije binarno iskodirati...
lakse u Rustu :heart:
 
Ponosan sam na sebe: ovoga puta sam znatno ubrzao kod, jer sam totalno promenio sqrt algoritam (ovo gore koristi deljenje, pa je sporije).
Sta sve ima na vikipediji :P
Ima tu malo i komentara ovog puta...
Kod:
    /// Approximates `1/x` for the 8 f64 lanes of AMX Z row `zrow_in`, writing the result to
    /// Z row `zrow_out`.
    ///
    /// The magic-constant initial guess is computed on the CPU (64-bit integer subtraction)
    /// and refined with three AMX rounds. Explicitly loads Z1, X1, Y1, X2 (scratch).
    fn rcp(&mut self,zrow_in:u64,zrow_out:u64){
      let one = [1.0;8];
      let zero = [0.0;8];
      let mut magic:[u64;8] = [0x7FDE6238502484BA;8];
      let mut row = [0u64;8];
      unsafe { self.store512(&mut row,ZRow(zrow_in as usize));}
      // `wrapping_sub`: bit patterns above the magic constant (e.g. negative inputs)
      // overflowed plain `-=` and panicked in debug builds.
      for (v, bits) in magic.iter_mut().zip(row.iter()) {
        *v = v.wrapping_sub(*bits);
      }
      unsafe {
        self.load512(&one,ZRow(1));
        self.load512(&magic,XRow(1));
        self.load512(&row,YRow(1));
        self.load512(&zero,XRow(2));
      }
      self.fms64_vec(1,1,1);
      self.extr_y(1,1);
      for _ in 0..3 {
        self.fma64_vec_x(1,1);
        self.fma64_vec(1,1,1);
        self.extr_x(1,1);
        self.extr_xy(0,1);
        // Zero Z2 by copying the all-zero X2 row into it.
        self.fma64_vec_x(2,2);
        self.fma64_vec(2,0,1);
        self.extr_y(2,1);
      }
      self.extr_x(1,1);
      self.fma64_vec(1,1,1);
      // Copy the result from Z1 into the requested output row via X0.
      self.extr_x(1,0);
      self.fma64_vec_x(zrow_out,0);
    }
    /// Computes `sqrt(s)` for the 8 f64 lanes of AMX Z row `zrow_in`, writing the result to
    /// Z row `zrow_out`.
    ///
    /// The method has three stages:
    /// - A CPU-side f32 guess for `1/sqrt(s)` (the 0x5f3759df bit trick plus one Newton step).
    /// - Eight AMX Newton iterations `x <- x * (3 - s*x^2) / 2` in f64.
    /// - The final product `s * (1/sqrt(s))`.
    ///
    /// Scratch rows: Z0, Z60..Z63, X0, X7, Y0.
    fn sqrt(& mut self, zrow_in:u64,zrow_out:u64){
      let mut number = [0.0f64;8];
      unsafe {self.store512(&mut number,ZRow(zrow_in as usize));}
      for x in number.iter_mut() {
        let xf = *x as f32;
        // `wrapping_sub`: negative inputs (sign bit set) make `bits >> 1` exceed the magic
        // constant, which overflowed and panicked in debug builds.
        let y = f32::from_bits(0x5f3759df_u32.wrapping_sub(xf.to_bits() >> 1));
        // Newton step for 1/sqrt(x) is y * (1.5 - 0.5 * x * y^2); the old code used y^3.
        let y = y * (1.5 - 0.5 * xf * y * y);
        *x = y as f64;
      }
      let zero = [0.0f64;8];
      let three = [3.0f64;8];
      let zero_point_five = [0.5f64;8];
      unsafe {
        self.load512(&number,ZRow(60));
        self.load512(&zero,ZRow(63));
        self.load512(&three,ZRow(62));
        self.load512(&zero_point_five,ZRow(61));
      }
      for _ in 0..8 {
        self.extr_y(60,0);// xn -> Y
        self.extr_x(60,0);// xn -> X
        self.extr_x(63,7);
        self.fma64_vec_x(0,7);// 0 -> Z
        self.fma64_vec(0,0,0);// xn ^2
        self.extr_x(zrow_in,0);// s -> X
        self.extr_y(0,0); // Z -> Y
        self.extr_x(62,7);
        self.fma64_vec_x(0,7);// 3 -> Z
        self.fms64_vec(0,0,0);// 3 - s * xn ^ 2
        self.extr_x(60,0);// xn -> X
        self.extr_y(0,0);// Z -> Y
        self.extr_x(63,7);
        self.fma64_vec_x(0,7); // 0 -> Z
        self.fma64_vec(0,0,0);// xn * (3- s * xn ^2)
        self.extr_x(0,0);// Z -> X
        self.extr_y(61,0); // 0.5 -> Y
        self.extr_x(63,7);
        self.fma64_vec_x(0,7);// 0 -> Z
        self.fma64_vec(0,0,0);// xn * (3 - s * xn^2)/2
        self.extr_x(0,0);
        self.fma64_vec_x(60,0); // result -> Z[60]
      }
      // X0 still holds the last x_n (copied from Z0 at the end of the loop).
      self.extr_y(zrow_in,0);// s -> Y
      self.extr_x(63,7);
      self.fma64_vec_x(zrow_out,7);// 0 -> Z
      self.fma64_vec(zrow_out,0,0);// s * 1/sqrt(s)
    }
Nisam hteo da implementiram u AMX prvi priblizni broj racunanje, jer opet AMX nema 64 bitnu int aritmetiku, a ta jedna formula,
ne uzima tolko vremena...
 
Jos optimizacije sqrt funkcije
Kod:
    /// Computes `sqrt(s)` for the 8 f64 lanes of AMX Z row `zrow_in`, writing the result to
    /// Z row `zrow_out`.
    ///
    /// The method has three stages:
    /// - A CPU-side f32 guess for `1/sqrt(s)` via the 0x5f3759df bit trick.
    /// - Four AMX Newton iterations `x <- x * (3 - s*x^2) / 2` in f64.
    /// - The final product `s * (1/sqrt(s))`.
    ///
    /// Scratch rows: Z0, Z60..Z63, X0, X7, Y0.
    fn sqrt(& mut self, zrow_in:u64,zrow_out:u64){
      let mut number = [0.0f64;8];
      unsafe {self.store512(&mut number,ZRow(zrow_in as usize));}
      for x in number.iter_mut() {
        let xf = *x as f32;
        // `wrapping_sub`: negative inputs (sign bit set) make `bits >> 1` exceed the magic
        // constant, which overflowed and panicked in debug builds.
        *x = f32::from_bits(0x5f3759df_u32.wrapping_sub(xf.to_bits() >> 1)) as f64;
      }
      let zero = [0.0f64;8];
      let three = [3.0f64;8];
      let zero_point_five = [0.5f64;8];
      unsafe {
        self.load512(&number,ZRow(60));
        self.load512(&zero,ZRow(63));
        self.load512(&three,ZRow(62));
        self.load512(&zero_point_five,ZRow(61));
      }
      for _ in 0..4 {
        self.extr_y(60,0);// xn -> Y
        self.extr_x(60,0);// xn -> X
        self.extr_x(63,7);
        self.fma64_vec_x(0,7);// 0 -> Z
        self.fma64_vec(0,0,0);// xn ^2
        self.extr_x(zrow_in,0);// s -> X
        self.extr_y(0,0); // Z -> Y
        self.extr_x(62,7);
        self.fma64_vec_x(0,7);// 3 -> Z
        self.fms64_vec(0,0,0);// 3 - s * xn ^ 2
        self.extr_x(60,0);// xn -> X
        self.extr_y(0,0);// Z -> Y
        self.extr_x(63,7);
        self.fma64_vec_x(0,7); // 0 -> Z
        self.fma64_vec(0,0,0);// xn * (3- s * xn ^2)
        self.extr_x(0,0);// Z -> X
        self.extr_y(61,0); // 0.5 -> Y
        self.extr_x(63,7);
        self.fma64_vec_x(0,7);// 0 -> Z
        self.fma64_vec(0,0,0);// xn * (3 - s * xn^2)/2
        self.extr_x(0,0);
        self.fma64_vec_x(60,0); // result -> Z[60]
      }
      self.extr_y(zrow_in,0);// s -> Y
      self.extr_x(63,7);
      self.fma64_vec_x(zrow_out,7);// 0 -> Z
      self.extr_x(60,0);// 1/sqrt(s) -> X
      self.fma64_vec(zrow_out,0,0);// s * 1/sqrt(s)
    }
 
Pretabao u makro, i poboljsao kodiranje.
Evo kako izgleda:
Kod:
/// Emit an AMX instruction whose operand is passed in a general-purpose register.
///
/// The assembler substitutes `{operand}` with a register name such as `x21`. Prefixing it
/// with `0` turns that into the hex literal `0x21` (= 33). `v - (v >> 4) * 6` then maps the
/// two "hex" digits back to the decimal register number (33 - 12 = 21), which is OR-ed into
/// the low 5 bits of the instruction word.
///
/// No `.align` directive: on AArch64 `.align 8` means 2^8 = 256-byte alignment, which
/// padded every AMX instruction with NOPs. The instruction only needs the natural 4-byte
/// alignment that `.word` in a code section already has. AMX ops do not touch the stack
/// or the condition flags, hence `nostack` / `preserves_flags`. Must be expanded inside
/// an `unsafe` context.
macro_rules! op_in {
    {$OP:tt , $operand:tt} => {
        asm!(
            ".word (0x201000 + ({op} << 5) + 0{operand} - ((0{operand} >> 4) * 6))",
            op = const $OP,
            operand = in(reg) $operand,
            options(nostack, preserves_flags),
        );
    };
}
/// Emit an AMX instruction with a 5-bit immediate.
///
/// The word is `0x00201000 + (op << 5) + operand`, with both `op` and `operand` known at
/// compile time. No `.align` directive: on AArch64 `.align 8` requests 256-byte alignment
/// and inserted NOP padding before every instruction. AMX ops do not touch the stack or the
/// flags, so `nostack` / `preserves_flags` let the compiler generate tighter code around
/// them. Must be expanded inside an `unsafe` context.
macro_rules! op_imm {
    { $OP:tt, $OPERAND:tt } => {
        asm!(
            ".word 0x00201000 + ({op} << 5) + {operand}",
            op = const $OP,
            operand = const $OPERAND,
            options(nostack, preserves_flags),
        );
    };
}
dakle ima dve varijante, sa immediate operandom, tj konstantom, i sa ulaznim registrem.
Varijanta sa ulaznim registrom, koji je na AArch64 u formatu xNN, mora da se pretvori
u dekadni broj, jer asembler ubacuje ime registra, recimo x21.
Dakle, broj koji asembler vidi kao heksadecimalni (zato se ispred lepi 0, pa x21 postaje 0x21) treba pretvoriti u dekadni 21.
I zato ova pretumbacija. Inace instrukcija, kako vidite, je zbir konstante, operacije
koja predstavlja instrukciju, i operanda koji dolazi u run time.
Operand sadrzi sve varijante instrukcija sa izborom registara i samu operaciju
koja treba da se izvrsi.
Instrukcija je validna potpuno, sve se generise u compile time,
ali je opet jezivo sporo, mada bi trebalo da bude bar dvaput brze od Neona.
Sta sam mogao, optimizovao sam glede Rust liba, i tu se postiglo neko
ubrzanje od 15%. Naravno, nedovoljno jer je Neon 0.5 sekundi izvrsavanja
bench programa, a ovo 13.5 sekundi, sto je 27 puta sporije umesto obrnuto.
Gledam gde je overhead, ne mogu da verujem da je hardware toliko spor.
evo finalne verzije rcp i sqrt:
Kod:
   /// Approximates `1/x` for the 8 f64 lanes of AMX Z row `zrow_in`, writing the result to
   /// Z row `zrow_out`.
   ///
   /// The magic-constant initial guess is computed on the CPU (64-bit integer subtraction)
   /// and refined with three AMX rounds. `one` is not defined in this snippet; presumably
   /// it is a `[1.0f64; 8]` constant in scope — NOTE(review): confirm.
   fn rcp(&mut self,zrow_in:u64,zrow_out:u64){
      let mut row = [0u64;8];
      let mut magic:[u64;8] = [0x7FDE6238502484BA;8];
      unsafe { self.store512(&mut row,ZRow(zrow_in as usize));}
      // `wrapping_sub`: bit patterns above the magic constant (e.g. negative inputs)
      // overflowed plain `-=` and panicked in debug builds.
      for (v, bits) in magic.iter_mut().zip(row.iter()) {
        *v = v.wrapping_sub(*bits);
      }
      unsafe {
        self.load512(&one,ZRow(1));
        self.load512(&magic,XRow(1));
      }
      // The input is moved to Y1 directly on the AMX side instead of reloading it.
      self.extr_y(zrow_in,1);
      self.fms64_vec(1,1,1,0);
      self.extr_y(1,1);
      for _ in 0..3 {
        self.fma64_vec_x(1,1);
        self.fma64_vec(1,1,1,0);
        self.extr_x(1,1);
        self.extr_xy(0,1);
        self.fma64_vec_xy(2,0,1,0);
        self.extr_y(2,1);
      }
      self.extr_x(1,1);
      self.fma64_vec(1,1,1,0);
      // Copy the result from Z1 into the requested output row via X0.
      self.extr_x(1,0);
      self.fma64_vec_x(zrow_out,0);
    }
   /// Computes `sqrt(s)` for the 8 f64 lanes of AMX Z row `zrow_in`, writing the result to
   /// Z row `zrow_out`.
   ///
   /// The method has three stages:
   /// - A CPU-side f32 guess for `1/sqrt(s)` via the 0x5f3759df bit trick.
   /// - Three AMX Newton iterations `x <- x * (3 - s*x^2) / 2` in f64, kept in Z row 60.
   /// - The final product `s * (1/sqrt(s))`.
   ///
   /// `three` and `zero_point_five` are not defined in this snippet; presumably they are
   /// `[3.0f64; 8]` / `[0.5f64; 8]` constants in scope — NOTE(review): confirm.
   fn sqrt(& mut self, zrow_in:u64,zrow_out:u64){
      let mut number = [0.0f64;8];
      unsafe {self.store512(&mut number,ZRow(zrow_in as usize));}
      for x in number.iter_mut() {
        let xf = *x as f32;
        // Bit trick via to_bits/from_bits instead of transmuting `&mut f32` to `*mut u32`.
        // `wrapping_sub`: negative inputs (sign bit set) make `bits >> 1` exceed the magic
        // constant, which overflowed and panicked in debug builds.
        *x = f32::from_bits(0x5f3759df_u32.wrapping_sub(xf.to_bits() >> 1)) as f64;
      }
      unsafe {
        self.load512(&number,ZRow(60));
        self.load512(&three,ZRow(62));
        self.load512(&zero_point_five,ZRow(61));
      }
      for _ in 0..3 {
        self.extr_y(60,0);// xn -> Y
        self.extr_x(60,0);// xn -> X
        self.fma64_vec_xy(0,0,0,0);// xn ^2
        self.extr_x(zrow_in,0);// s -> X
        self.extr_y(0,0); // Z -> Y
        self.extr_x(62,7);
        self.fma64_vec_x(0,7);// 3 -> Z
        self.fms64_vec(0,0,0,0);// 3 - s * xn ^ 2
        self.extr_x(60,0);// xn -> X
        self.extr_y(0,0);// Z -> Y
        self.fma64_vec_xy(0,0,0,0);// xn * (3- s * xn ^2)
        self.extr_x(0,0);// Z -> X
        self.extr_y(61,0); // 0.5 -> Y
        self.fma64_vec_xy(0,0,0,0);// xn * (3 - s * xn^2)/2
        self.extr_x(0,0);
        self.fma64_vec_x(60,0); // result -> Z[60]
      }
      self.extr_y(zrow_in,0);// s -> Y
      self.extr_x(60,0);// 1/sqrt(s) -> X
      self.fma64_vec_xy(zrow_out,0,0,0);// s * 1/sqrt(s)
    }
 
I pogledam benchmarke: vektorski fma64 ima najslabije performanse, samo 11 GFLOPS, nasuprot
recimo matricnih koje imaju 91 gflops. Onda fma32 postize duplo od toga, a u matricnoj
varijanti oko 370gflops, sto je oko duplo brze od procesorskih klasicnih instrukcija
koje postizu oko 160gflopsa. Naravno, ako problem podelis na threadove onda
se to znatno uvecava, no mene zanima single thread...
 

Back
Top