Say we have a loop whose body branches on the data:

for (int i=0;i<n;i++) {
	if (vec[i]<7)
		vec[i]=vec[i]*a+b;
	else
		vec[i]=c;
}
We can replace the branch with bitwise logic: compute the "then" half unconditionally, turn the comparison into an all-ones or all-zeros mask, and use the mask to pick between the two answers.  (In plain C the bitwise operations have to be applied to the float's bits, so we need memcpy from <string.h>; SSE can apply them to float registers directly.)

for (int i=0;i<n;i++) {
	unsigned int mask=(vec[i]<7)?0xffFFffFF:0; /* all-ones if true, zero if false */
	float t=vec[i]*a+b; /* "then" value; the "else" value is just c */
	unsigned int ti,ci; memcpy(&ti,&t,4); memcpy(&ci,&c,4); /* grab the float bits */
	unsigned int ri=(ti&mask)|(ci&~mask); /* mask picks "then" or "else" bits */
	memcpy(&vec[i],&ri,4);
}

Written in ordinary sequential code, this is actually a slowdown, not a speedup!  But in SSE this branch-to-logical transformation means you can keep barreling along in parallel, without having to switch to sequential floating point to do the branches:
__m128 A=_mm_load1_ps(&a), B=_mm_load1_ps(&b), C=_mm_load1_ps(&c);
__m128 Thresh=_mm_load1_ps(&thresh); // thresh holds the cutoff value, 7.0
for (int i=0;i<n;i+=4) {
	__m128 V=_mm_load_ps(&vec[i]);
	__m128 mask=_mm_cmplt_ps(V,Thresh); // do all four comparisons at once
	__m128 V_then=_mm_add_ps(_mm_mul_ps(V,A),B); // "then" half of "if"
	__m128 V_else=C; // "else" half of "if"
	V=_mm_or_ps( _mm_and_ps(mask,V_then), _mm_andnot_ps(mask,V_else) );
	_mm_store_ps(&vec[i],V);
}

This gives about a 3.8x speedup over the original loop on my machine!
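That 3.8x figure is machine- and compiler-dependent.  If you want to measure it yourself outside NetRun, a rough standalone harness along the lines of this sketch works; the function names, array size, and repetition count are my own arbitrary choices, not part of the lecture code:

/* Rough timing harness: branchy loop vs. SSE masked loop (sketch only). */
#include <xmmintrin.h>
#include <chrono>
#include <cstdio>

enum {n=4096};
alignas(16) float vec[n]; /* aligned, since _mm_load_ps requires it */
float a=2.0f, b=1.0f, c=17.0f, thresh=7.0f;

void branchy(void) {
	for (int i=0;i<n;i++)
		if (vec[i]<thresh) vec[i]=vec[i]*a+b; else vec[i]=c;
}
void sse_masked(void) {
	__m128 A=_mm_load1_ps(&a), B=_mm_load1_ps(&b), C=_mm_load1_ps(&c);
	__m128 T=_mm_load1_ps(&thresh);
	for (int i=0;i<n;i+=4) {
		__m128 V=_mm_load_ps(&vec[i]);
		__m128 mask=_mm_cmplt_ps(V,T);
		V=_mm_or_ps(_mm_and_ps(mask,_mm_add_ps(_mm_mul_ps(V,A),B)),
		            _mm_andnot_ps(mask,C));
		_mm_store_ps(&vec[i],V);
	}
}
template <class Fn> double seconds_per_run(Fn f,int runs) {
	auto t0=std::chrono::high_resolution_clock::now();
	for (int r=0;r<runs;r++) f();
	auto t1=std::chrono::high_resolution_clock::now();
	return std::chrono::duration<double>(t1-t0).count()/runs;
}
int main(void) {
	for (int i=0;i<n;i++) vec[i]=(float)(i%13); /* data on both sides of the cutoff */
	std::printf("branchy: %.3g s/run\n",seconds_per_run(branchy,10000));
	for (int i=0;i<n;i++) vec[i]=(float)(i%13); /* reset so both loops see the same data */
	std::printf("SSE:     %.3g s/run\n",seconds_per_run(sse_masked,10000));
	return 0;
}

Either way, raw intrinsics are clunky to read and write.  C++ operator overloading lets us hide the SSE machinery behind a wrapper class, so parallel code reads like ordinary arithmetic: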
#include <xmmintrin.h>
class fourfloats; /* forward declaration */
/* Wrapper around four bitmasks: 0 if false, all-ones (NAN) if true.
This class is used to implement comparisons on SSE registers.
*/
class fourmasks {
	__m128 mask;
public:
	fourmasks(__m128 m) {mask=m;}
	__m128 if_then_else(fourfloats dthen,fourfloats delse);
};
/* Nice wrapper around __m128:
it represents four floating point values. */
class fourfloats {
	__m128 v;
public:
	fourfloats(float onevalue) { v=_mm_load1_ps(&onevalue); }
	fourfloats(__m128 ssevalue) {v=ssevalue;} // put in an SSE value
	operator __m128 () const {return v;} // take out an SSE value
	fourfloats(const float *fourvalues) { v=_mm_load_ps(fourvalues);}
	void store(float *fourvalues) {_mm_store_ps(fourvalues,v);}

	/* arithmetic operations return blocks of floats */
	fourfloats operator+(const fourfloats &right) const {
		return _mm_add_ps(v,right.v);
	}
	/* comparison operations return blocks of masks (bools) */
	fourmasks operator<(const fourfloats &right) const {
		return _mm_cmplt_ps(v,right.v);
	}
};
inline __m128 fourmasks::if_then_else(fourfloats dthen,fourfloats delse) {
	// the mask keeps the "then" bits where true, the "else" bits where false
	return _mm_or_ps( _mm_and_ps(mask,dthen),
	                  _mm_andnot_ps(mask,delse) );
}
With the wrapper, the branch becomes a single if_then_else expression:

alignas(16) float src[4]={1.0,5.0,3.0,4.0}; /* 16-byte aligned, as _mm_load_ps requires */
alignas(16) float dest[4];
int foo(void) {
	/*
	// Serial code
	for (int i=0;i<4;i++) {
		if (src[i]<4.0) dest[i]=src[i]*2.0; else dest[i]=17.0;
	}
	*/
	// Parallel code
	fourfloats s(src);
	fourfloats d=(s<4.0).if_then_else(s+s,17.0);
	d.store(dest);
	//farray_print(dest,4); // <- for debugging
	return 0;
}
Compilers are now good enough that there is zero speed penalty for using the "fourfloats" class: the nicer syntax comes for free!
Here's the slightly more developed wrapper we built in lecture:

#include <pmmintrin.h> /* for SSE3 _mm_hadd_ps */
#include <iostream> /* for std::ostream, used by operator<< below */
// Secret stupid class: do not use...
class not_vec4 {
	__m128 v; // bitwise inverse of our value (!!)
public:
	not_vec4(__m128 val) {v=val;}
	__m128 get(void) const {return v;} // returns INVERSE of our value (!!)
};
// This is the one to use!
class vec4 {
	__m128 v;
public:
	vec4(__m128 val) {v=val;}
	vec4(const float *src) {v=_mm_loadu_ps(src);} // unaligned load
	vec4(float x) {v=_mm_set_ps1(x);} // splat one float into all four lanes

	vec4 operator+(const vec4 &rhs) const {return _mm_add_ps(v,rhs.v);}
	vec4 operator-(const vec4 &rhs) const {return _mm_sub_ps(v,rhs.v);}
	vec4 operator*(const vec4 &rhs) const {return _mm_mul_ps(v,rhs.v);}
	vec4 operator/(const vec4 &rhs) const {return _mm_div_ps(v,rhs.v);}
	vec4 operator&(const vec4 &rhs) const {return _mm_and_ps(v,rhs.v);}
	vec4 operator|(const vec4 &rhs) const {return _mm_or_ps(v,rhs.v);}
	vec4 operator^(const vec4 &rhs) const {return _mm_xor_ps(v,rhs.v);}
	// comparisons return an all-ones or all-zeros mask in each lane
	vec4 operator==(const vec4 &rhs) const {return _mm_cmpeq_ps(v,rhs.v);}
	vec4 operator!=(const vec4 &rhs) const {return _mm_cmpneq_ps(v,rhs.v);}
	vec4 operator<(const vec4 &rhs) const {return _mm_cmplt_ps(v,rhs.v);}
	vec4 operator<=(const vec4 &rhs) const {return _mm_cmple_ps(v,rhs.v);}
	vec4 operator>(const vec4 &rhs) const {return _mm_cmpgt_ps(v,rhs.v);}
	vec4 operator>=(const vec4 &rhs) const {return _mm_cmpge_ps(v,rhs.v);}
	not_vec4 operator~(void) const {return not_vec4(v);}
	__m128 get(void) const {return v;}
	float *store(float *ptr) {
		_mm_storeu_ps(ptr,v); // unaligned store, to match the unaligned-load constructor
		return ptr;
	}
	float &operator[](int index) { return ((float *)&v)[index]; }
	float operator[](int index) const { return ((const float *)&v)[index]; }
	friend std::ostream &operator<<(std::ostream &o,const vec4 &y) {
		o<<y[0]<<" "<<y[1]<<" "<<y[2]<<" "<<y[3];
		return o;
	}
	// operator& with a not_vec4 maps to the single andnot instruction
	friend vec4 operator&(const vec4 &lhs,const not_vec4 &rhs)
		{return _mm_andnot_ps(rhs.get(),lhs.get());}
	friend vec4 operator&(const not_vec4 &lhs,const vec4 &rhs)
		{return _mm_andnot_ps(lhs.get(),rhs.get());}
	// *this is a mask: keep "then" bits where true, "else" bits where false
	vec4 if_then_else(const vec4 &then,const vec4 &else_part) const {
		return _mm_or_ps( _mm_and_ps(v,then.v),
		                  _mm_andnot_ps(v,else_part.v) );
	}
};
vec4 sqrt(const vec4 &v) {
	return _mm_sqrt_ps(v.get());
}
vec4 rsqrt(const vec4 &v) { // approximate 1.0/sqrt(x): fast, but only about 12 bits
	return _mm_rsqrt_ps(v.get());
}
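/* The hardware rsqrt above is only a ~12-bit approximation.  If you need
   more accuracy, a standard trick (an addition of mine, not part of the
   lecture wrapper) is one Newton-Raphson step, which roughly doubles
   the number of correct bits: */
vec4 rsqrt_refined(const vec4 &x) {
	vec4 y=rsqrt(x); // hardware estimate
	return y*(vec4(1.5f)-vec4(0.5f)*x*y*y); // y' = y*(1.5 - 0.5*x*y*y)
}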
/* Return value = dot product of a & b, replicated 4 times */
inline vec4 dot(const vec4 &a,const vec4 &b) {
	vec4 t=a*b; // lanes: {a0*b0, a1*b1, a2*b2, a3*b3}
	__m128 vt=_mm_hadd_ps(t.get(),t.get()); // {t0+t1, t2+t3, t0+t1, t2+t3}
	vt=_mm_hadd_ps(vt,vt); // {sum, sum, sum, sum}
	return vt;
}
For example, computing a dot product with the wrapper:

float data[4]={1.2f,2.3f,3.4f,1.5f};
vec4 a(data);
vec4 b(1.0);
vec4 c(0.0);
int foo(void) {
	c=b*dot(a,b);
	return 0;
}
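Finally, to tie this back to where we started: with vec4, the branchy loop from the top of the page becomes a short branch-free loop.  This is a sketch of mine, not lecture code; it assumes n is a multiple of 4 (vec needn't be aligned, since vec4 uses unaligned loads and stores):

/* The original branchy loop, rewritten with the vec4 wrapper. */
void branch_free_loop(float *vec,int n,float a,float b,float c) {
	vec4 A(a), B(b), C(c), Thresh(7.0f);
	for (int i=0;i<n;i+=4) {
		vec4 V(&vec[i]); // load four elements
		// compute BOTH halves of the "if", and let the mask pick the survivor:
		(V<Thresh).if_then_else(V*A+B,C).store(&vec[i]);
	}
}

There are no branches left at all: every element pays for both halves of the "if", but the loop never stalls on a mispredicted jump, exactly like the hand-written intrinsic version above.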