SIMD Branching

CS 641 Lecture, Dr. Lawlor

As an aside, my lecture notes on bitwise AND are here:
http://www.cs.uaf.edu/2009/fall/cs301/lecture/09_04_bits.html

Here are some more details on writing an interpreter for machine code:
http://www.cs.uaf.edu/2009/fall/cs301/lecture/09_14_tables.html

These may help with homework problem 3 (write an interpreter).

SIMD Branching Approaches

Per-Float Branching in SSE

SSE has a really curious set of instructions to support per-float branching: comparisons that produce per-float bitmasks, plus bitwise AND, ANDNOT, and OR.
Compare-and-AND is actually useful to simulate branches.  These instructions come in handy when you're trying to convert a loop like this to SSE:
	for (int i=0;i<n;i++) { // scalar loop: one compare-and-branch per element
        if (vec[i]<7)
vec[i]=vec[i]*a+b; // "then" side: elements below 7
else
vec[i]=c; // "else" side: everything else
}
(Try this in NetRun now!)

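You can implement this branch by computing a mask indicating where vec[i]<7. Then use the mask to keep the correct side of the branch and squash the other. (C++ won't let you apply & or | directly to a float, so read the line below as pseudocode acting on the floats' 32-bit patterns.)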
	for (int i=0;i<n;i++) { // branch-free version of the loop above
        unsigned int mask=(vec[i]<7)?0xffFFffFF:0; // all 32 bits set where the test is true, all clear otherwise
vec[i]=((vec[i]*a+b)&mask) | (c&~mask); // NOTE: pseudocode -- C++ rejects & and | on float operands; read these as acting on the floats' bit patterns
}
Written in ordinary sequential code, this is actually a slowdown, not a speedup!  But in SSE this branch-to-logical transformation means you can keep barreling along in parallel, without having to switch to sequential floating point to do the branches. Here thresh stands in for the constant 7:
	__m128 A=_mm_load1_ps(&a), B=_mm_load1_ps(&b), C=_mm_load1_ps(&c); // broadcast each scalar into all four lanes
__m128 Thresh=_mm_load1_ps(&thresh); // thresh is not declared here; presumably it holds the 7 from the scalar loop -- confirm
for (int i=0;i<n;i+=4) { // NOTE: steps by 4 with no remainder loop, so this assumes n is a multiple of 4
__m128 V=_mm_load_ps(&vec[i]); // aligned load: &vec[i] must be 16-byte aligned
__m128 mask=_mm_cmplt_ps(V,Thresh); // Do all four comparisons
__m128 V_then=_mm_add_ps(_mm_mul_ps(V,A),B); // "then" half of "if"
__m128 V_else=C; // "else" half of "if"
V=_mm_or_ps( _mm_and_ps(mask,V_then), _mm_andnot_ps(mask,V_else) ); // V_then where mask is all-ones, V_else where it is zero
_mm_store_ps(&vec[i],V); // aligned store: same 16-byte alignment requirement as the load
}

(Try this in NetRun now!)

This gives about a 3.8x speedup over the original loop on my machine!

Intel hinted in their Larrabee paper that NVIDIA's CUDA uses this same branch-to-mask transformation. CUDA is NVIDIA's very high-performance language for running sequential-looking code in parallel on the graphics card.

Hiding SSE Nastiness In a Wrapper Class

Raw SSE code is just ugly, and comparisons are doubly so.  You can hide the ugliness inside a "wrapper class":
#include <xmmintrin.h>

class fourfloats; /* forward declaration */

/* Wrapper around four bitmasks: 0 if false, all-ones (NAN) if true.
   This class is used to implement comparisons on SSE registers.
*/
class fourmasks {
    __m128 mask;
public:
    fourmasks(__m128 m) {mask=m;}
    /* Per lane: pick dthen where the mask is all-ones, delse where it is zero. */
    __m128 if_then_else(fourfloats dthen,fourfloats delse);
};

/* Nice wrapper around __m128:
   it represents four floating point values. */
class fourfloats {
    __m128 v;
public:
    /* Broadcast one value into all four lanes. */
    fourfloats(float onevalue) { v=_mm_load1_ps(&onevalue); }

    fourfloats(__m128 ssevalue) {v=ssevalue;} // put in an SSE value
    operator __m128 () const {return v;} // take out an SSE value

    /* Aligned load/store: the pointer must be 16-byte aligned, because
       _mm_load_ps/_mm_store_ps fault on unaligned addresses. */
    fourfloats(const float *fourvalues) { v=_mm_load_ps(fourvalues);}
    void store(float *fourvalues) const {_mm_store_ps(fourvalues,v);}

    /* arithmetic operations return blocks of floats */
    fourfloats operator+(const fourfloats &right) const {
        return _mm_add_ps(v,right.v);
    }

    /* comparison operations return blocks of masks (bools) */
    fourmasks operator<(const fourfloats &right) const {
        return _mm_cmplt_ps(v,right.v);
    }
};

inline __m128 fourmasks::if_then_else(fourfloats dthen,fourfloats delse) {
    /* The two masked halves have disjoint bits, so merge them with a bitwise OR.
       Do not use float '+' here:
       - it turns a selected -0.0 into +0.0 (-0.0 + +0.0 == +0.0);
       - it quiets signaling NaNs;
       - '+' on __m128 is a compiler extension, not portable SSE. */
    return _mm_or_ps(_mm_and_ps(mask,dthen),
                     _mm_andnot_ps(mask,delse));
}

/* Demo input/output. fourfloats loads and stores with the aligned
   _mm_load_ps/_mm_store_ps, so both arrays must be 16-byte aligned.
   alignas(16) guarantees that on every ABI, instead of relying on the
   platform's default alignment for arrays. */
alignas(16) float src[4]={1.0,5.0,3.0,4.0};
alignas(16) float dest[4];

/* Computes dest[i] = (src[i] < 4) ? src[i]*2 : 17 for all four lanes at once.
   For the src above, dest ends up as {2, 17, 6, 17}. */
int foo(void) {
    /*
    // Serial code
    for (int i=0;i<4;i++) {
        if (src[i]<4.0) dest[i]=src[i]*2.0; else dest[i]=17.0;
    }
    */
    // Parallel code
    fourfloats s(src);
    fourfloats d=(s<4.0).if_then_else(s+s,17.0); // s+s computes src[i]*2 in every lane
    d.store(dest);

    //farray_print(dest,4); // <- for debugging
    return 0;
}

(Try this in NetRun now!)

With optimization enabled, compilers are now good enough that the nice "fourfloats" class carries essentially zero speed penalty: the nice syntax comes for free!

Here's the somewhat more complete wrapper we built in lecture:
#include <pmmintrin.h> /* for SSE3 hadd_ps */

// Secret helper class: do not use directly...
// vec4::operator~ returns one of these. It stores the ORIGINAL bits v,
// but stands for their bitwise inverse ~v. That way "a & ~b" can compile
// to a single _mm_andnot_ps instead of a NOT followed by an AND.
class not_vec4 {
    __m128 v; // bitwise inverse of our value (!!)
public:
    not_vec4(__m128 val) {v=val;}
    __m128 get(void) const {return v;} // returns INVERSE of our value (!!)
};

// This is the one to use!
// Wraps one __m128 holding either four floats, or four lane masks
// (all-ones = true, all-zeros = false) produced by the comparison operators.
class vec4 {
    __m128 v;
public:
    vec4(__m128 val) {v=val;}
    vec4(const float *src) {v=_mm_loadu_ps(src);} // unaligned load: any float* works
    vec4(float x) {v=_mm_set_ps1(x);}             // broadcast x into all four lanes

    // Lane-wise arithmetic.
    vec4 operator+(const vec4 &rhs) const {return _mm_add_ps(v,rhs.v);}
    vec4 operator-(const vec4 &rhs) const {return _mm_sub_ps(v,rhs.v);}
    vec4 operator*(const vec4 &rhs) const {return _mm_mul_ps(v,rhs.v);}
    vec4 operator/(const vec4 &rhs) const {return _mm_div_ps(v,rhs.v);}

    // Lane-wise bitwise operations (mostly useful on comparison masks).
    vec4 operator&(const vec4 &rhs) const {return _mm_and_ps(v,rhs.v);}
    vec4 operator|(const vec4 &rhs) const {return _mm_or_ps(v,rhs.v);}
    vec4 operator^(const vec4 &rhs) const {return _mm_xor_ps(v,rhs.v);}

    // Lane-wise comparisons: each result lane is all-ones (true) or all-zeros (false).
    vec4 operator==(const vec4 &rhs) const {return _mm_cmpeq_ps(v,rhs.v);}
    vec4 operator!=(const vec4 &rhs) const {return _mm_cmpneq_ps(v,rhs.v);}
    vec4 operator<(const vec4 &rhs) const {return _mm_cmplt_ps(v,rhs.v);}
    vec4 operator<=(const vec4 &rhs) const {return _mm_cmple_ps(v,rhs.v);}
    vec4 operator>(const vec4 &rhs) const {return _mm_cmpgt_ps(v,rhs.v);}
    vec4 operator>=(const vec4 &rhs) const {return _mm_cmpge_ps(v,rhs.v);}

    // Deferred NOT: only meaningful when combined with & (see the friends below).
    not_vec4 operator~(void) const {return not_vec4(v);}

    __m128 get(void) const {return v;}

    // Write all four lanes to ptr and return ptr.
    // Uses the unaligned store to match the unaligned load in vec4(const float*),
    // so ptr does not need 16-byte alignment.
    float *store(float *ptr) const {
        _mm_storeu_ps(ptr,v);
        return ptr;
    }

    // Per-lane element access (index 0..3, unchecked).
    float &operator[](int index) { return ((float *)&v)[index]; }
    float operator[](int index) const { return ((const float *)&v)[index]; }

    friend std::ostream &operator<<(std::ostream &o,const vec4 &y) {
        o<<y[0]<<" "<<y[1]<<" "<<y[2]<<" "<<y[3];
        return o;
    }
    // a & ~b  ->  _mm_andnot_ps(b, a), which computes (~b) & a in one instruction.
    friend vec4 operator&(const vec4 &lhs,const not_vec4 &rhs)
    {return _mm_andnot_ps(rhs.get(),lhs.get());}
    friend vec4 operator&(const not_vec4 &lhs, const vec4 &rhs)
    {return _mm_andnot_ps(lhs.get(),rhs.get());}

    // Treat *this as a lane mask: take 'then' where it is all-ones, 'else_part' where it is zero.
    vec4 if_then_else(const vec4 &then,const vec4 &else_part) const {
        return _mm_or_ps( _mm_and_ps(v,then.v),
                          _mm_andnot_ps(v, else_part.v)
                        );
    }
};

/* Lane-wise square root (full precision, _mm_sqrt_ps).
   Declared inline, like dot() below, so the definition can be shared
   through a header without multiple-definition link errors. */
inline vec4 sqrt(const vec4 &v) {
    return _mm_sqrt_ps(v.get());
}
/* Lane-wise APPROXIMATE reciprocal square root, about 1/sqrt(v).
   _mm_rsqrt_ps is only good to roughly 12 bits; Intel documents a
   maximum relative error of 1.5*2^-12. Use 1/sqrt(v), or add a
   Newton-Raphson step, when full float precision matters.
   Declared inline, like dot() below, so it is header-safe. */
inline vec4 rsqrt(const vec4 &v) {
    return _mm_rsqrt_ps(v.get());
}
/* Four-wide dot product: returns sum(a[i]*b[i]) broadcast into every lane.
   Needs SSE3 for _mm_hadd_ps. */
inline vec4 dot(const vec4 &a,const vec4 &b) {
    const __m128 prod = (a*b).get();               // [p0, p1, p2, p3]
    const __m128 pairs = _mm_hadd_ps(prod, prod);  // [p0+p1, p2+p3, p0+p1, p2+p3]
    return _mm_hadd_ps(pairs, pairs);              // full sum in all four lanes
}

float data[4]={1.2,2.3,3.4,1.5};
vec4 a(data);
vec4 b(1.0);
vec4 c(0.0);

int foo(void) {
c=b*dot(a,b);
return 0;
}

(Try this in NetRun now!)