angel
mercurial changeset:
|
00001 // $Id: gmpi.hpp,v 1.4 2004/02/22 18:44:46 gottschling Exp $ 00002 /* 00003 ############################################################# 00004 # This file is part of angel released under the BSD license # 00005 # The full COPYRIGHT notice can be found in the top # 00006 # level directory of the angel distribution # 00007 ############################################################# 00008 */ 00009 00010 00011 #ifndef _gmpi_include_ 00012 #define _gmpi_include_ 00013 00014 00015 #include "mpi.h" 00016 #include <utility> 00017 #include <vector> 00018 #include <list> 00019 #include <deque> 00020 // #include <complex> 00021 00022 00023 namespace GMPI { 00024 00025 using std::pair; using std::vector; using std::list; using std::deque; 00026 00028 inline MPI::Datatype which_mpi_t (char) {return MPI::CHAR;} 00029 // inline MPI::Datatype which_mpi_t (wchar_t) {return MPI::WCHAR;} 00030 // inline MPI::Datatype which_mpi_t (signed char) {return MPI::SIGNED_CHAR;} 00031 inline MPI::Datatype which_mpi_t (signed short) {return MPI::SHORT;} 00032 inline MPI::Datatype which_mpi_t (signed int) {return MPI::INT;} 00033 inline MPI::Datatype which_mpi_t (signed long) {return MPI::LONG;} 00034 inline MPI::Datatype which_mpi_t (unsigned char) {return MPI::UNSIGNED_CHAR;} 00035 inline MPI::Datatype which_mpi_t (unsigned short) {return MPI::UNSIGNED_SHORT;} 00036 inline MPI::Datatype which_mpi_t (unsigned) {return MPI::UNSIGNED;} 00037 // inline MPI::Datatype which_mpi_t (unsigned int) {return MPI::UNSIGNED_INT;} 00038 inline MPI::Datatype which_mpi_t (unsigned long) {return MPI::UNSIGNED_LONG;} 00039 inline MPI::Datatype which_mpi_t (double) {return MPI::DOUBLE;} 00040 inline MPI::Datatype which_mpi_t (long double) {return MPI::LONG_DOUBLE;} 00041 // inline MPI::Datatype which_mpi_t (bool) {return MPI::BOOL;} 00042 // inline MPI::Datatype which_mpi_t (complex<float>) {return MPI::COMPLEX;} 00043 // inline MPI::Datatype which_mpi_t (complex<double>) {return MPI::DOUBLE_COMPLEX;} 00044 // inline MPI::Datatype which_mpi_t (complex<long double>) {return MPI::LONG_DOUBLE_COMPLEX;} 00045 inline MPI::Datatype which_mpi_t (pair<int,int>) {return MPI::TWOINT;} 00046 inline MPI::Datatype which_mpi_t (pair<float,int>) {return MPI::FLOAT_INT;} 00047 inline MPI::Datatype which_mpi_t (pair<double,int>) {return MPI::DOUBLE_INT;} 00048 inline MPI::Datatype which_mpi_t (pair<long double,int>) {return MPI::LONG_DOUBLE_INT;} 00049 inline MPI::Datatype which_mpi_t (pair<short,int>) {return MPI::SHORT_INT;} 00050 00052 const MPI::Datatype mpi_size_t= which_mpi_t (size_t()); 00053 00054 template <typename Base_t> 00055 class buffer_t { 00056 public: 00057 typedef Base_t base_t; 00058 MPI::Datatype mpi_t; 00059 const int base_t_size; 00060 private: 00061 vector<Base_t> my_buffer; // Vector, to use reallocation and copy 00062 mutable size_t write_pos, // Written entries (of Base_t) 00063 read_pos; // Read entries (of Base_t) 00064 mutable bool read_write; // Either read or write, true == read 00065 public: 00067 buffer_t () : mpi_t (which_mpi_t (Base_t())), 00068 base_t_size (sizeof (Base_t)), 00069 write_pos (0), read_pos (0), read_write (false) {} 00071 buffer_t (size_t n) : mpi_t (which_mpi_t (Base_t())), 00072 base_t_size (sizeof (Base_t)), 00073 write_pos (0), read_pos (0), read_write (false) {reserve (n);} 00074 00076 void reserve (size_t n) { 00077 if (read_write) { // change to write 00078 write_pos= 0; my_buffer.resize (n); read_write= false; return; } 00079 if (write_pos+n > my_buffer.size()) my_buffer.resize (write_pos+n);} 00080 00082 bool empty () const { 00083 return read_pos == write_pos; } 00084 00086 size_t remaining () const { 00087 return write_pos - read_pos; } 00088 00090 Base_t read () const { 00091 if (!read_write) { // change to read 00092 read_pos= 0; read_write= true; } 00093 // if (read_pos >= write_pos) throw gmpi_exception ("Read past end"); 00094 return my_buffer[read_pos++]; } 00095 00097 void unread () const { 00098 // if (!read_write) throw gmpi_exception ("Unread in write mode"); 00099 // if (read_pos == 0) throw gmpi_exception ("Unread at the beginning"); 00100 read_pos--; } 00101 00103 void write (Base_t output) { 00104 if (read_write) { // change to write 00105 write_pos= 0; my_buffer.resize (1); read_write= false; } 00106 // if (write_pos > my_buffer.size()) throw gmpi_exception ("Written past end"); 00107 reserve (1); 00108 my_buffer[write_pos++]= output; } 00109 00111 void load (Base_t* address, size_t n) const { 00112 if (!read_write) { // change to read 00113 read_pos= 0; read_write= true; } 00114 // if (read_pos+n > write_pos) throw gmpi_exception ("Read past end"); 00115 memcpy (address, &my_buffer[read_pos], n * base_t_size); 00116 read_pos+= n; } 00117 00119 void store (Base_t* address, size_t n) { 00120 if (read_write) { // change to write 00121 write_pos= 0; my_buffer.resize (n); read_write= false; } 00122 reserve (write_pos+n - my_buffer.size()); 00123 memcpy (&my_buffer[write_pos], address, n * base_t_size); 00124 write_pos+= n; } 00125 00127 size_t size () const {return my_buffer.size(); } 00128 00130 Base_t* address () {return &my_buffer[0]; } 00131 00133 const Base_t* buffer_address () const {return &my_buffer[0]; } 00134 00136 void free () { 00137 my_buffer.resize (0); write_pos= read_pos= 0; read_write= false; } 00138 }; 00139 00141 template <typename Base_t> inline 00142 const buffer_t<Base_t>& operator>> (const buffer_t<Base_t>& buffer, Base_t& input) { 00143 input= buffer.read (); return buffer; } 00144 00146 template <typename Base_t> inline 00147 buffer_t<Base_t>& operator<< (buffer_t<Base_t>& buffer, const Base_t& output) { 00148 buffer.write (output); return buffer; } 00149 00151 template <typename Base_t> inline 00152 const buffer_t<Base_t>& operator>> (const buffer_t<Base_t>& buffer, 00153 vector<Base_t>& input) { 00154 int csize= input.size(), n= buffer.remaining(); 00155 input.resize (csize + n); 00156 buffer.load (&input[csize], n); 00157 return buffer; } 00158 00160 template <typename Base_t> inline 00161 const buffer_t<Base_t>& operator<< (buffer_t<Base_t>& buffer, 00162 const vector<Base_t>& output) { 00163 int n= output.size(); 00164 buffer.store (&output[0], n); 00165 return buffer; } 00166 00167 // =========== Derived Operators ============================================ 00168 00170 template <typename Base_t, typename Scalar1_t, typename Scalar2_t> inline 00171 const buffer_t<Base_t>& operator>> (const buffer_t<Base_t>& buffer, 00172 pair<Scalar1_t, Scalar2_t>& input) { 00173 buffer >> input.first >> input.second; return buffer; } 00174 00176 template <typename Base_t, typename Scalar1_t, typename Scalar2_t> inline 00177 buffer_t<Base_t>& operator<< (buffer_t<Base_t>& buffer, 00178 const pair<Scalar1_t, Scalar2_t>& output) { 00179 buffer << output.first << output.second; return buffer; } 00180 00182 template <typename Base_t, typename Scalar_t> inline 00183 const buffer_t<Base_t>& operator>> (const buffer_t<Base_t>& buffer, 00184 vector<Scalar_t>& input) { 00185 size_t size; buffer >> size; 00186 for (size_t c= 0; c < size; c++) { 00187 Scalar_t scalar; buffer >> scalar; input.push_back (scalar); } 00188 return buffer; } 00189 00191 template <typename Base_t, typename Scalar_t> inline 00192 buffer_t<Base_t>& operator<< (buffer_t<Base_t>& buffer, 00193 const vector<Scalar_t>& output) { 00194 buffer << output.size(); // Base_t must be large enough for the vector size 00195 for (typename vector<Scalar_t>::const_iterator it= output.begin(), end= output.end(); 00196 it != end;) buffer << *it++; 00197 return buffer; } 00198 00200 template <typename Base_t, typename Scalar_t> inline 00201 const buffer_t<Base_t>& operator>> (const buffer_t<Base_t>& buffer, 00202 list<Scalar_t>& input) { 00203 size_t size; buffer >> size; 00204 for (size_t c= 0; c < size; c++) { 00205 Scalar_t scalar; buffer >> scalar; input.push_back (scalar); } 00206 return buffer; } 00207 00209 template <typename Base_t, typename Scalar_t> inline 00210 buffer_t<Base_t>& operator<< (buffer_t<Base_t>& buffer, 00211 const list<Scalar_t>& output) { 00212 buffer << output.size(); // Base_t must be large enough for the list size 00213 for (typename list<Scalar_t>::const_iterator it= output.begin(), end= output.end(); it != end;) 00214 buffer << *it++; 00215 return buffer; } 00216 00218 template <typename Base_t, typename Scalar_t> inline 00219 const buffer_t<Base_t>& operator>> (const buffer_t<Base_t>& buffer, 00220 deque<Scalar_t>& input) { 00221 size_t size; buffer >> size; 00222 for (size_t c= 0; c < size; c++) { 00223 Scalar_t scalar; buffer >> scalar; input.push_back (scalar); } 00224 return buffer; } 00225 00227 template <typename Base_t, typename Scalar_t> inline 00228 buffer_t<Base_t>& operator<< (buffer_t<Base_t>& buffer, 00229 const deque<Scalar_t>& output) { 00230 buffer << output.size(); // Base_t must be large enough for the deque size 00231 for (typename deque<Scalar_t>::const_iterator it= output.begin(), end= output.end(); 00232 it != end;) buffer << *it++; 00233 return buffer; } 00234 00235 template <typename Base_t, typename Object_t> 00236 class comm_ref_t { 00237 public: 00238 typedef Base_t base_t; 00239 typedef Object_t object_t; 00240 private: 00241 object_t& my_object_ref; 00242 public: 00243 comm_ref_t (object_t& o) : my_object_ref (o) {} 00244 object_t& object_ref() {return my_object_ref; } 00245 const object_t& object_ref() const {return my_object_ref; } 00246 }; 00247 00248 class Comm { 00249 protected: 00250 MPI::Intracomm my_comm; 00251 public: 00252 Comm (const MPI::Intracomm& mpi_comm) : my_comm (mpi_comm.Dup()) {} 00253 00254 MPI::Intracomm& mpi_comm_ref () {return my_comm;} 00255 const MPI::Intracomm& mpi_comm_ref () const {return my_comm;} 00256 00258 template <typename Comm_ref_t> 00259 void Send (const Comm_ref_t& data, int dest, int tag) const; 00260 00261 template <typename Comm_ref_t> 00262 void Recv (Comm_ref_t& data, int source, int tag, MPI::Status& status) const; 00263 00264 int Get_size () const {return my_comm.Get_size (); } 00265 00266 int Get_rank () const {return my_comm.Get_rank (); } 00267 }; 00268 00269 class Intracomm : public Comm { 00270 public: 00271 Intracomm (MPI::Intracomm& mpi_comm) : Comm (mpi_comm) {} 00272 00273 void Barrier () const {my_comm.Barrier (); } 00274 00275 template <typename Comm_ref_t> 00276 void Bcast (Comm_ref_t& data, int root) const; 00277 00278 void Bcast (void* buffer, int count, const MPI::Datatype& datatype, int root) const { 00279 my_comm.Bcast (buffer, count, datatype, root); } 00280 00281 template <typename Comm_ref_t> 00282 void Reduce (const Comm_ref_t& senddata, Comm_ref_t& recvdata, 00283 const MPI::Op& op, int root) const; 00284 00285 template <typename Comm_ref_t> 00286 void Allreduce (const Comm_ref_t& senddata, Comm_ref_t& recvdata, 00287 const MPI::Op& op) const; 00288 00289 // template <typename Base_t> 00290 // void Allreduce (const Base_t& senddata, Base_t& recvdata, const MPI::Op& op) const { 00291 // my_comm.Allreduce (&senddata, &recvdata, 1, which_mpi_t (senddata), op); } 00292 00293 void Allreduce (const void* sendbuf, void* recvbuf, int count, 00294 const MPI::Datatype& datatype, const MPI::Op& op) const { 00295 my_comm.Allreduce (sendbuf, recvbuf, count, datatype, op); } 00296 00297 void Gather (const void* sendbuf, int sendcount, const MPI::Datatype& sendtype, 00298 void* recvbuf, int recvcount, const MPI::Datatype& recvtype, int root) const { 00299 my_comm.Gather (sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root); } 00300 }; 00301 00302 } // namespace GMPI 00303 00304 #include "gmpi_impl.hpp" // long template implementations 00305 00306 #endif // _gmpi_include_ 00307 00308 00309 00310 00311 00312