toon-members
[Top][All Lists]
Advanced

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[Toon-members] TooN internal/operators.hh test/mmult_test.cc


From: Tom Drummond
Subject: [Toon-members] TooN internal/operators.hh test/mmult_test.cc
Date: Mon, 09 Mar 2009 11:55:26 +0000

CVSROOT:        /cvsroot/toon
Module name:    TooN
Changes by:     Tom Drummond <twd20>    09/03/09 11:55:26

Modified files:
        internal       : operators.hh 
        test           : mmult_test.cc 

Log message:
        Added Matrix * Vector
        Vector * Matrix (distinct to handle non commuting data types)
        and cleaned up some of the template usage for Matrix*Matrix

CVSWeb URLs:
http://cvs.savannah.gnu.org/viewcvs/TooN/internal/operators.hh?cvsroot=toon&r1=1.10&r2=1.11
http://cvs.savannah.gnu.org/viewcvs/TooN/test/mmult_test.cc?cvsroot=toon&r1=1.1&r2=1.2

Patches:
Index: internal/operators.hh
===================================================================
RCS file: /cvsroot/toon/TooN/internal/operators.hh,v
retrieving revision 1.10
retrieving revision 1.11
diff -u -b -r1.10 -r1.11
--- internal/operators.hh       6 Mar 2009 12:42:35 -0000       1.10
+++ internal/operators.hh       9 Mar 2009 11:55:25 -0000       1.11
@@ -74,9 +74,9 @@
        };
 
        //FIXME what about BLAS?
-       template<typename Precision> struct MatrixMultiply
+       struct MatrixMultiply
        {
-               template<int R, int C, typename B, int R1, int C1, typename P1, 
typename B1, int R2, int C2, typename P2, typename B2> 
+               template<int R, int C, typename Precision, typename B, int R1, 
int C1, typename P1, typename B1, int R2, int C2, typename P2, typename B2> 
                static void eval(Matrix<R, C, Precision, B>& res, const 
Matrix<R1, C1, P1, B1>& m1, const Matrix<R2, C2, P2, B2>& m2)
                {
                        for(int i=0; i < res.num_rows(); ++i)
@@ -85,6 +85,30 @@
                }
        };
 
+       struct MatrixVectorMultiply // res = m * v kernel; passed as Operator<Internal::MatrixVectorMultiply> by operator*(Matrix, Vector)
+       {
+               template<int Sout, typename Pout, typename Bout, int R, int C, int Size, typename P1, typename P2, typename B1, typename B2>
+               static void eval(Vector<Sout, Pout, Bout>& res, const Matrix<R, C, P1, B1>& m, const Vector<Size, P2, B2>& v) // no size check here: operator*(Matrix, Vector) checks m.num_cols() == v.size() and passes m.num_rows() as the result size
+               {
+                       for(int i=0; i < res.size(); ++i){ // one output element per row of m
+                               res[i] = m[i] * v; // m[i] is row i of m; assumes Vector * Vector is the dot product
+                       }
+               }
+       };
+
+       // Kept distinct from MatrixVectorMultiply so that v stays the left operand of every product, for element types whose multiplication does not commute
+       struct VectorMatrixMultiply // res = v * m kernel; passed as Operator<Internal::VectorMatrixMultiply> by operator*(Vector, Matrix)
+       {
+               template<int Sout, typename Pout, typename Bout, int R, int C, int Size, typename P1, typename P2, typename B1, typename B2>
+               static void eval(Vector<Sout, Pout, Bout>& res, const Vector<Size, P1, B1>& v, const Matrix<R, C, P2, B2>& m) // no size check here: operator*(Vector, Matrix) checks v.size() == m.num_rows() and passes m.num_cols() as the result size
+               {
+                       for(int i=0; i < res.size(); ++i){ // one output element per column of m
+                               res[i] = v * m.T()[i]; // column i of m == row i of m.T(); row m[i] has length C (not R) and runs past the last row when C > R
+                       }
+               }
+       };
+
+
        //Mini operators for passing to Pairwise, etc
        struct Add{ template<class A, class B, class C>      static A op(const 
B& b, const C& c){return b+c;} };
        struct Subtract{ template<class A, class B, class C> static A op(const 
B& b, const C& c){return b-c;} };
@@ -164,13 +188,32 @@
 // Matrix multiplication Matrix * Matrix
 
+// Returns the R1 x C2 product m1 * m2 with element type Internal::MultiplyType<P1, P2>::type.
+// Only the inner dimensions must agree (m1.num_cols() == m2.num_rows()); m1.num_rows() need not equal m2.num_cols().
 template<int R1, int C1, int R2, int C2, typename P1, typename P2, typename B1, typename B2> 
-Matrix<Internal::Sizer<R1,R1>::size, Internal::Sizer<C2,C2>::size, typename Internal::MultiplyType<P1, P2>::type> operator*(const Matrix<R1, C1, P1, B1>& m1, const Matrix<R2, C2, P2, B2>& m2)
+Matrix<R1, C2, typename Internal::MultiplyType<P1, P2>::type> operator*(const Matrix<R1, C1, P1, B1>& m1, const Matrix<R2, C2, P2, B2>& m2)
 {
        typedef typename Internal::MultiplyType<P1, P2>::type restype;
 
-       SizeMismatch<R1, C2>:: test(m1.num_rows(),m2.num_cols());
        SizeMismatch<C1, R2>:: test(m1.num_cols(),m2.num_rows());
-       return Matrix<Internal::Sizer<R1,R1>::size, Internal::Sizer<C2,C2>::size,restype>(m1, m2, Operator<Internal::MatrixMultiply<restype> >(), m1.num_rows(), m2.num_cols());
+       return Matrix<R1, C2, restype>(m1, m2, Operator<Internal::MatrixMultiply>(), m1.num_rows(), m2.num_cols()); // construct exactly the declared return type
+}
+
+// Matrix * Vector: an R-row matrix times v gives an R-element vector with element type Internal::MultiplyType<P1,P2>::type
+
+template<int R, int C, int Size, typename P1, typename P2, typename B1, typename B2>
+Vector<R, typename Internal::MultiplyType<P1,P2>::type> operator*(const Matrix<R, C, P1, B1>& m, const Vector<Size, P2, B2>& v)
+{
+       SizeMismatch<C,Size>::test(m.num_cols(), v.size()); // columns of m must match the length of v
+       return Vector<R, typename Internal::MultiplyType<P1,P2>::type> (m, v, Operator<Internal::MatrixVectorMultiply>(), m.num_rows() ); // m.num_rows() is the result size; elements presumably filled by Internal::MatrixVectorMultiply::eval via this constructor
+}
+
+// Vector * Matrix: v times an R x C matrix gives a C-element vector; v stays the left operand so non-commuting element types are handled
+
+template<int Size, int R, int C, typename P1, typename P2, typename B1, typename B2>
+Vector<C, typename Internal::MultiplyType<P1,P2>::type> operator*(const Vector<Size, P1, B1>& v, const Matrix<R, C, P2, B2>& m)
+{
+       SizeMismatch<R,Size>::test(m.num_rows(), v.size()); // length of v must match the rows of m
+       return Vector<C, typename Internal::MultiplyType<P1,P2>::type> (v, m, Operator<Internal::VectorMatrixMultiply>(), m.num_cols() ); // m.num_cols() is the result size; NOTE(review): the VectorMatrixMultiply kernel must dot v with column i of m, not row m[i]
 }
 
 

Index: test/mmult_test.cc
===================================================================
RCS file: /cvsroot/toon/TooN/test/mmult_test.cc,v
retrieving revision 1.1
retrieving revision 1.2
diff -u -b -r1.1 -r1.2
--- test/mmult_test.cc  27 Feb 2009 09:45:47 -0000      1.1
+++ test/mmult_test.cc  9 Mar 2009 11:55:26 -0000       1.2
@@ -24,11 +24,20 @@
        m4[1] = makeVector(8, 9);
        m4[2] = makeVector(10, 11);
 
+       Vector<V(a,3)> v(3);
+       v = makeVector(6,8,10);
+
        cout << m3<<endl;
        cout << m4<<endl;
        cout << m3*m4;
        
        cout << "\n should be: \n    28    31\n  100   112\n";
+
+       cout << endl << v << endl;
+       cout << endl << m3*v << endl;
+
+       cout << "\n should be: \n    28    100\n" << endl;
+
 }
 
 int main()




reply via email to

[Prev in Thread] Current Thread [Next in Thread]