SSE Branching: Performance Comparison

CS 641 Lecture, Dr. Lawlor

First, a surprising result comparing SSE 'branching' with the branch prediction in modern CPUs (see the 441 lecture branch_vs_SSE): unpredictable branches cost up to 7x in performance on modern machines, but SSE runs at exactly the same speed whether branches are predictable or not--because it doesn't try to predict branches at all!
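To see the difference in miniature, compare a branchy scalar loop against a branchless SSE version of the same conditional sum.  This is only a sketch of the kind of code that lecture benchmarks--the function names, the data, and the assumption that n is a multiple of 4 are mine, purely for illustration:

	#include <xmmintrin.h> /* SSE intrinsics */

	/* Scalar version: one data-dependent branch per element. */
	float sum_if_scalar(const float *data,int n,float thresh) {
		float sum=0.0f;
		for (int i=0;i<n;i++)
			if (data[i]<thresh) sum+=data[i]; /* mispredicts often on random data */
		return sum;
	}

	/* SSE version: no branch--mask away the elements we don't want. */
	float sum_if_sse(const float *data,int n,float thresh) { /* n assumed a multiple of 4 */
		__m128 SUM=_mm_setzero_ps();
		__m128 T=_mm_set1_ps(thresh);
		for (int i=0;i<n;i+=4) {
			__m128 D=_mm_loadu_ps(&data[i]);
			__m128 MASK=_mm_cmplt_ps(D,T); /* all-ones where data<thresh */
			SUM=_mm_add_ps(SUM,_mm_and_ps(MASK,D)); /* adds data[i] or 0.0 */
		}
		float s[4]; _mm_storeu_ps(s,SUM);
		return s[0]+s[1]+s[2]+s[3];
	}

With random data, the scalar version's branch mispredicts about half the time; the SSE version does the same arithmetic on every lane and lets the mask zero out the unwanted additions, so its speed doesn't depend on the data at all.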

movemask and SSE Branching

Let's start with the Mandelbrot set fractal:

#include <iostream>
#include <fstream> /* for ofstream */
#include <complex> /* for fractal arithmetic */

/**
 A linear function in 2 dimensions: returns a double as a function of (x,y).
*/
class linear2d_function {
public:
	double a,b,c;
	void set(double a_,double b_,double c_) {a=a_;b=b_;c=c_;}
	linear2d_function(double a_,double b_,double c_) {set(a_,b_,c_);}
	double evaluate(double x,double y) const {return x*a+y*b+c;}
};

int foo(void)
{
	// Figure out how big an image we should render:
	int wid=350, ht=256;

	// Create a PPM output image file header:
	std::ofstream out("out.ppm",std::ios_base::binary);
	out<<"P6\n"
	   <<wid<<" "<<ht<<"\n"
	   <<"255\n";

	// Set up coordinate system to render the Mandelbrot Set:
	double scale=3.0/wid;
	linear2d_function fx(scale,0.0,-scale*wid/2); // returns c given pixels
	linear2d_function fy(0.0,scale,-1.0);

	for (int y=0;y<ht;y++)
	for (int x=0;x<wid;x++) {
		/* Walk this Mandelbrot Set pixel */
		typedef std::complex<double> COMPLEX;
		COMPLEX c(fx.evaluate(x,y),fy.evaluate(x,y));
		COMPLEX z(0.0);
		int count;
		enum {max_count=100};
		for (count=0;count<max_count;count++) {
			z=z*z+c;
			if ((z.real()*z.real()+z.imag()*z.imag())>=4.0) break;
		}

		/* Figure out the output pixel color */
		unsigned char r,g,b;
		r=(unsigned char)(z.real()*(256/2.0));
		g=(unsigned char)(z.imag()*(256/2.0));
		b=(unsigned char)(((z.real()*z.real()+z.imag()*z.imag()))*256);
		out<<r<<g<<b;
	}

	return 0;
}

(Try this in NetRun now!)

This takes 57.8ms to render a 350x256 PPM image.  The first thing to note is that this line is sequential:
		out<<r<<g<<b;
We can't write to the output file in parallel, so we'll have to collect the output values into an array of "pixel" objects first:
typedef std::complex<double> COMPLEX; /* moved to file scope so pixel can see it */

class pixel { /* onscreen RGB pixel */
public:
	unsigned char r,g,b;
	pixel() {}
	pixel(const COMPLEX &z) {
		r=(unsigned char)(z.real()*(256/2.0));
		g=(unsigned char)(z.imag()*(256/2.0));
		b=(unsigned char)(((z.real()*z.real()+z.imag()*z.imag()))*256);
	}
};

int foo(void)
{
	...
	pixel *output=new pixel[wid*ht];

	for (int y=0;y<ht;y++)
	for (int x=0;x<wid;x++) {
		...

		output[x+y*wid]=pixel(z);
	}
	out.write((char *)&output[0],sizeof(pixel)*wid*ht);
	delete[] output;

	return 0;
}

(Try this in NetRun now!)

Doing the output in one big block is nearly twice as fast--just 29.8ms.  Man!  Even thinking about running in parallel sped up this program substantially!  I think this is mostly due to the fact that the inner loop now has no function calls, so the compiler doesn't have to save and restore registers across the I/O function call.

SSE

Let's try to use SSE.  Here's the inner loop:
	COMPLEX c(fx.evaluate(x,y),fy.evaluate(x,y));
	COMPLEX z(0.0);
	int count;
	enum {max_count=100};
	for (count=0;count<max_count;count++) {
		z=z*z+c;
		if ((z.real()*z.real()+z.imag()*z.imag())>=4.0) break;
	}
	output[x+y*wid]=pixel(z);
That "z*z" line is complex multiplication, so it's really "z=COMPLEX(z.r*z.r-z.i*z.i,2*z.r*z.i);".   Let's first rewrite this in terms of basic floating-point operations:
	COMPLEX c(fx.evaluate(x,y),fy.evaluate(x,y));
	float cr=c.real(), ci=c.imag();
	float zr=0.0, zi=0.0;
	int count;
	enum {max_count=100};
	for (count=0;count<max_count;count++) {
		float tzr=zr*zr-zi*zi+cr; /* subtle: don't overwrite zr yet! */
		float tzi=2*zr*zi+ci;
		zr=tzr; zi=tzi;
		if ((zr*zr+zi*zi)>=4.0) break;
	}
	output[x+y*wid]=pixel(COMPLEX(zr,zi));

(Try this in NetRun now!)

Surprisingly, again, this speeds up the program!  Now we're at 20ms.  We have two options for SSE:
  1. Put zr and zi into a single vector, and try to extract a little parallelism within the operations on it, like the duplicate multiplies.  Most of our work would be in shuffling values around.
  2. Put four different zr's into one vector, and four different zi's into another, then do the *exact* operations above for four separate pixels at once. 
Generally, option 1 is a little easier (for example, only one pixel at a time goes around the loop), but option 2 usually gives much higher performance.  To start with, we can just unroll the x loop 4 times, load everything into four-float SSE vectors (named in ALL CAPS), and use the handy little instruction _mm_movemask_ps, which extracts the sign bits of the 4 floats, to figure out when to stop the loop:
	/* needs #include <xmmintrin.h>; arithmetic on __m128 is a gcc extension.
	   The x loop above now steps by 4: for (int x=0;x<wid;x+=4) */
	float cr[4], ci[4];
	for (int i=0;i<4;i++) {
		cr[i]=fx.evaluate(x+i,y);
		ci[i]=fy.evaluate(x+i,y);
	}
	float zero[4]={0.0,0.0,0.0,0.0};
	float two[4]={2.0,2.0,2.0,2.0};
	__m128 TWO=_mm_load_ps(two);
	float four[4]={4.0,4.0,4.0,4.0};
	__m128 FOUR=_mm_load_ps(four);
	__m128 CR=_mm_load_ps(cr), CI=_mm_load_ps(ci);
	__m128 ZR=_mm_load_ps(zero), ZI=_mm_load_ps(zero);
	int count;
	enum {max_count=100};
	for (count=0;count<max_count;count++) {
		__m128 tZR=ZR*ZR-ZI*ZI+CR; /* subtle: don't overwrite ZR yet! */
		__m128 tZI=TWO*ZR*ZI+CI;
		ZR=tZR; ZI=tZI;
		if (_mm_movemask_ps(ZR*ZR+ZI*ZI-FOUR)==0) break; /* all four lengths >= 4.0 */
	}
	_mm_store_ps(cr,ZR); _mm_store_ps(ci,ZI); /* reuse cr/ci as scratch output */
	for (int i=0;i<4;i++) {
		output[x+i+y*wid]=pixel(COMPLEX(cr[i],ci[i]));
	}

(Try this in NetRun now!)
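(A quick aside on _mm_movemask_ps, since both versions of this loop lean on it: it packs the sign bit of each of the 4 floats into the low 4 bits of an ordinary integer, which we can then test with a normal branch.  A tiny standalone demo, with made-up values:)

	#include <xmmintrin.h>
	#include <iostream>

	int foo(void) {
		/* Lanes 0..3 hold -1.0, 2.0, -3.0, 4.0; note _mm_set_ps lists lanes high-to-low. */
		__m128 V=_mm_set_ps(4.0f,-3.0f,2.0f,-1.0f);
		int mask=_mm_movemask_ps(V); /* bit i = sign bit of lane i */
		std::cout<<"mask = "<<mask<<"\n"; /* prints 5 (binary 0101: lanes 0 and 2 are negative) */
		return 0;
	}

So "_mm_movemask_ps(ZR*ZR+ZI*ZI-FOUR)==0" above means "no lane's length is still below 4.0".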

The unrolled version runs, and it's really super fast, at just 8.5ms, but it gets the wrong answer.  The trouble is that all four pixels are chained together in the loop until they're all done: pixels that diverge early keep getting updated on every later iteration, which corrupts their final z values.  What we need is for only the pixels whose length is still less than 4.0 to keep going around the loop, and the others to stop changing.  We can use the standard (yet weird) SSE branch trick to fix this, essentially doing the loop's work only where we should keep iterating, like "if (tZR*tZR+tZI*tZI<FOUR) {ZR=tZR; ZI=tZI;}".
	float cr[4], ci[4];
	for (int i=0;i<4;i++) {
		cr[i]=fx.evaluate(x+i,y);
		ci[i]=fy.evaluate(x+i,y);
	}
	float zero[4]={0.0,0.0,0.0,0.0};
	float two[4]={2.0,2.0,2.0,2.0};
	__m128 TWO=_mm_load_ps(two);
	float four[4]={4.0,4.0,4.0,4.0};
	__m128 FOUR=_mm_load_ps(four);
	__m128 CR=_mm_load_ps(cr), CI=_mm_load_ps(ci);
	__m128 ZR=_mm_load_ps(zero), ZI=_mm_load_ps(zero);
	int count;
	enum {max_count=100};
	for (count=0;count<max_count;count++) {
		__m128 tZR=ZR*ZR-ZI*ZI+CR; /* subtle: don't overwrite ZR yet! */
		__m128 tZI=TWO*ZR*ZI+CI;
		__m128 LEN=tZR*tZR+tZI*tZI;
		__m128 MASK=_mm_cmplt_ps(LEN,FOUR); /* all-ones where len<4.0 */
		ZR=_mm_or_ps(_mm_and_ps(MASK,tZR),_mm_andnot_ps(MASK,ZR)); /* update only where MASK is set */
		ZI=_mm_or_ps(_mm_and_ps(MASK,tZI),_mm_andnot_ps(MASK,ZI));
		if (_mm_movemask_ps(LEN-FOUR)==0) break; /* everybody's length >= 4.0 */
	}
	_mm_store_ps(cr,ZR); _mm_store_ps(ci,ZI); /* reuse cr/ci as scratch output */
	for (int i=0;i<4;i++) {
		output[x+i+y*wid]=pixel(COMPLEX(cr[i],ci[i]));
	}

(Try this in NetRun now!)

This gets the right answer!  It's a tad slower, at 10.9ms, but getting the right answer is non-negotiable!
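As a footnote, the and/andnot/or dance above is such a common idiom that SSE4.1 added one instruction for it: _mm_blendv_ps picks each float from its second argument where the mask's sign bit is set, and from its first argument otherwise.  A sketch of the same update, assuming an SSE4.1-capable machine (e.g., compiled with gcc's -msse4.1); only the two select lines in the loop change:

	#include <smmintrin.h> /* SSE4.1, for _mm_blendv_ps */

	__m128 MASK=_mm_cmplt_ps(LEN,FOUR); /* all-ones where len<4.0 */
	ZR=_mm_blendv_ps(ZR,tZR,MASK); /* take tZR where MASK is set, else keep old ZR */
	ZI=_mm_blendv_ps(ZI,tZI,MASK);

This reads much more like the "if" we were trying to write in the first place.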