getfem-commits
[Top][All Lists]
Advanced

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[Getfem-commits] (no subject)


From: Yves Renard
Subject: [Getfem-commits] (no subject)
Date: Thu, 10 May 2018 03:20:43 -0400 (EDT)

branch: devel-yves-generic-assembly-modifs
commit 8a7be52188d65f00fe3eac825d9c3364a0ec65eb
Author: Yves Renard <address@hidden>
Date:   Tue May 8 15:39:09 2018 +0200

    grad(expression), work in progress
---
 interface/tests/python/check_asm.py     | 32 +++++++++--
 src/getfem_generic_assembly_semantic.cc | 99 +++++++++++++++++++++++++--------
 src/getfem_generic_assembly_tree.cc     |  3 +-
 3 files changed, 102 insertions(+), 32 deletions(-)

diff --git a/interface/tests/python/check_asm.py b/interface/tests/python/check_asm.py
index 3e0eb79..85cd615 100644
--- a/interface/tests/python/check_asm.py
+++ b/interface/tests/python/check_asm.py
@@ -148,11 +148,31 @@ res = gf.asm('expression analysis', "Grad(Grad_w/u)",  mim, 0, md)
 if (res != "((Hess_w/u)-((Grad_w/sqr(u))@Grad_u))"):
   print "Bad gradient"; exit(1)
   
-print 'Assembly string "Grad([u; 1; u])" gives:'
-res = gf.asm('expression analysis', "Grad([u; 1; u])",  mim, 0, md)
+# To be controlled after fixing C_MATRIX
+print 'Assembly string "Grad([u,u; 2,1; u,u])" gives:'
+res = gf.asm('expression analysis', "Grad([u,u; 2,1; u,u])",  mim, 0, md)
 
-print 'Assembly string "Grad([u, 1, u])" gives:'
-res = gf.asm('expression analysis', "Grad([u, 1, u])",  mim, 0, md)
 
-print 'Assembly string "[[Grad_u(1),Grad_u(2)],[0,0],[Grad_u(1),Grad_u(2)]]" 
gives:'
-res = gf.asm('expression analysis', 
"[[Grad_u(1),Grad_u(2)],[0,0],[Grad_u(1),Grad_u(2)]]",  mim, 0, md)
+# To be controlled after fixing C_MATRIX
+print 'Assembly string "Grad([[u,2,u],[u,1,u]])" gives:'
+res = gf.asm('expression analysis', "Grad([[u,2,u],[u,1,u]])",  mim, 0, md)
+
+
+
+print 'Assembly string "Grad(Reshape(Grad_w, 1, 4))" gives:'
+res = gf.asm('expression analysis', "Grad(Reshape(Grad_w, 1, 4))",  mim, 0, md)
+if (res != "(Reshape(Hess_w, 1, 4, 2))"): print "Bad gradient"; exit(1)
+
+print 'Assembly string "Grad(Grad_w(1,2))" gives:'
+res = gf.asm('expression analysis', "Grad(Grad_w(1,2))",  mim, 0, md)
+if (res != "(Hess_w(1, 2, :))"): print "Bad gradient"; exit(1)
+
+print 'Assembly string "Grad(Index_move_last(Grad_w, 1))" gives:'
+res = gf.asm('expression analysis', "Grad(Index_move_last(Grad_w, 1))", mim, 
0, md)
+if (res != "(Swap_indices(Index_move_last(Hess_w, 1), 2, 3))"):
+  print "Bad gradient"; exit(1)
+
+print 'Assembly string "Grad(Contract(Grad_w, 1, 2, Grad_w, 1, 2))" gives:'
+res = gf.asm('expression analysis', "Grad(Contract(Grad_w, 1, 2, Grad_w, 1, 
2))", mim, 0, md)
+if (res != "(Contract(Hess_w, 1, 2, Grad_w, 1, 2)+Contract(Grad_w, 1, 2, 
Hess_w, 1, 2))"):
+  print "Bad gradient"; exit(1)
diff --git a/src/getfem_generic_assembly_semantic.cc b/src/getfem_generic_assembly_semantic.cc
index 7e65f2c..983a70f 100644
--- a/src/getfem_generic_assembly_semantic.cc
+++ b/src/getfem_generic_assembly_semantic.cc
@@ -1395,16 +1395,16 @@ namespace getfem {
             if (nbc1 == 1 && nbc2 == 1 && nbc3 == 1)
               for (size_type i = 0; i < nbl; ++i)
                 pnode->tensor()[i] = pnode->children[i]->tensor()[0];
-            else if (nbc2 == 1 && nbc3 == 1) // TODO: verify order
+            else if (nbc2 == 1 && nbc3 == 1)
               for (size_type i = 0; i < nbl; ++i)
                 for (size_type j = 0; j < nbc1; ++j)
                   pnode->tensor()(i,j) = pnode->children[n++]->tensor()[0];
-            else if (nbc3 == 1) // TODO: verify order
+            else if (nbc3 == 1)
               for (size_type i = 0; i < nbl; ++i)
                 for (size_type j = 0; j < nbc2; ++j)
                   for (size_type k = 0; k < nbc1; ++k)
                     pnode->tensor()(i,j,k) = pnode->children[n++]->tensor()[0];
-            else // TODO: verify order
+            else
               for (size_type i = 0; i < nbl; ++i)
                 for (size_type j = 0; j < nbc3; ++j)
                   for (size_type k = 0; k < nbc2; ++k)
@@ -4209,15 +4209,25 @@ namespace getfem {
          }
        }
        if (m.dim() > 1) {
-         cout << "mi = " << pnode->tensor().sizes() << " : " <<  
pnode->tensor_order() << endl;
-         mi = pnode->tensor().sizes(); mi.push_back(m.dim());
-         cout << "mi = " << mi << endl;
-         pnode->t.adjust_sizes(mi);
-         size_type orgsize = pnode->children.size();
-         pnode->children.resize(pnode->tensor_proper_size(), nullptr);
-         for (size_type i = orgsize; i < pnode->children.size(); ++i) {
-           tree.copy_node(pnode->children[i-orgsize], pnode,
-                          pnode->children[i]);
+         size_type nbl = pnode->children.size() /
+           (pnode->nbc1*pnode->nbc2*pnode->nbc3);
+         if (pnode->nbc1==1 && pnode->nbc2==1 && pnode->nbc3==1)
+           pnode->nbc1 = m.dim();
+         else if (pnode->nbc2==1 && pnode->nbc3==1)
+           { pnode->nbc2 = pnode->nbc1; pnode->nbc1 = m.dim(); }
+         else if (pnode->nbc3==1)
+           { pnode->nbc3 = pnode->nbc2; pnode->nbc2 = pnode->nbc1;
+             pnode->nbc1 = m.dim(); }
+         else GMM_ASSERT1(false, "Sorry this exceed the current limit of "
+                          "constant tensors (limited to order four)");
+         pnode->children.resize(pnode->nbc1*pnode->nbc2*pnode->nbc3*nbl,
+                                nullptr);
+         for (size_type i = pnode->children.size()-1; i > 0; --i) {
+           if (i % m.dim())
+             tree.copy_node(pnode->children[i/m.dim()], pnode,
+                            pnode->children[i]);
+           else
+             std::swap(pnode->children[i/m.dim()], pnode->children[i]);
          }
          for (size_type i = 0; i < pnode->children.size(); ++i) {
            pga_tree_node child = pnode->children[i];
@@ -4225,24 +4235,66 @@ namespace getfem {
              tree.insert_node(child, GA_NODE_PARAMS);
              tree.add_child(child->parent, GA_NODE_CONSTANT);
              child->parent->children[1]
-               ->init_scalar_tensor(scalar_type(1+i/orgsize));
+               ->init_scalar_tensor(scalar_type(1+i%m.dim()));
            }
          }
        }
       }
       break;
 
-#ifdef continue_here
     case GA_NODE_PARAMS:
       if (child0->node_type == GA_NODE_RESHAPE) {
-        ga_node_grad(tree, workspace, m, pnode->children[1],
-                           varname, interpolatename, order);
-      }        else if (child0->node_type == GA_NODE_IND_MOVE_LAST) {
-        // TODO !!!!
+        ga_node_grad(tree, workspace, m, pnode->children[1]);
+       tree.add_child(pnode, GA_NODE_CONSTANT);
+       pnode->children.back()->init_scalar_tensor(scalar_type(m.dim()));
+      } else if (child0->node_type == GA_NODE_IND_MOVE_LAST) {
+       size_type order = pnode->tensor_order();
+       ga_node_grad(tree, workspace, m, pnode->children[1]);
+       tree.insert_node(pnode, GA_NODE_PARAMS);
+       tree.add_child(pnode->parent); tree.add_child(pnode->parent);
+       tree.add_child(pnode->parent);
+       std::swap(pnode->parent->children[0], pnode->parent->children[1]);
+       pnode->parent->children[0]->node_type = GA_NODE_SWAP_IND;
+       pnode->parent->children[2]->node_type = GA_NODE_CONSTANT;
+       pnode->parent->children[3]->node_type = GA_NODE_CONSTANT;
+       pnode->parent->children[2]->init_scalar_tensor(scalar_type(order));
+       pnode->parent->children[3]->init_scalar_tensor(scalar_type(order+1));
       }        else if (child0->node_type == GA_NODE_SWAP_IND) {
-        // TODO !!!!
+        ga_node_grad(tree, workspace, m, pnode->children[1]);
       }        else if (child0->node_type == GA_NODE_CONTRACT) {
-        // TODO !!!! (avec mark1 et "child2"->marked
+       mark0 = mark1;
+       size_type ch2 = 0;
+       if (pnode->children.size() == 5) ch2 = 3;
+       if (pnode->children.size() == 7) ch2 = 4;
+       mark1 = pnode->children[ch2]->marked;
+         
+       if (pnode->children.size() == 4) {
+         ga_node_grad(tree, workspace, m, pnode->children[1]);
+       } else {
+         pga_tree_node pg1(pnode), pg2(pnode);
+         if (mark0 && mark1) {
+           tree.duplicate_with_addition(pnode);
+           pg2 = pnode->parent->children[1];
+         }
+         if (mark0) {
+           size_type nred = pg1->children[1]->tensor_order();
+           if (pnode->children.size() == 7) nred--;
+           ga_node_grad(tree, workspace, m, pg1->children[1]);
+           tree.insert_node(pg1, GA_NODE_PARAMS);
+           tree.add_child(pg1->parent); tree.add_child(pg1->parent);
+           std::swap(pg1->parent->children[0], pg1->parent->children[1]);
+           pg1->parent->children[0]->node_type = GA_NODE_IND_MOVE_LAST;
+           pg1->parent->children[2]->node_type = GA_NODE_CONSTANT;
+           pg1->parent->children[2]->init_scalar_tensor(scalar_type(nred));
+         }
+         if (mark1) {
+           ga_node_grad(tree, workspace, m, pg2->children[ch2]);
+         }
+         ga_print_node(pg1, cout); cout << endl;
+         ga_print_node(pg2, cout); cout << endl;
+       }
+#ifdef continue_here
+
       } else if (child0->node_type == GA_NODE_PREDEF_FUNC) {
         std::string name = child0->name;
         ga_predef_function_tab::const_iterator it = 
PREDEF_FUNCTIONS.find(name);
@@ -4442,14 +4494,13 @@ namespace getfem {
 
           }
         }
-
+#endif
       } else {
-        ga_node_derivation(tree, workspace, m, child0, varname,
-                           interpolatename, order);
+        ga_node_grad(tree, workspace, m, child0);
+       tree.add_child(pnode, GA_NODE_ALLINDICES);
       }
       break;
 
-#endif
 
     default: GMM_ASSERT1(false, "Unexpected node type " << pnode->node_type
                          << " in derivation. Internal error.");
diff --git a/src/getfem_generic_assembly_tree.cc b/src/getfem_generic_assembly_tree.cc
index 7febcb2..7e0320d 100644
--- a/src/getfem_generic_assembly_tree.cc
+++ b/src/getfem_generic_assembly_tree.cc
@@ -1062,8 +1062,7 @@ namespace getfem {
     case GA_NODE_C_MATRIX:
       {
         GMM_ASSERT1(pnode->children.size(), "Invalid tree");
-        size_type nbc1 = pnode->nbc1;
-        size_type nbc2 = pnode->nbc2;
+        size_type nbc1 = pnode->nbc1, nbc2 = pnode->nbc2;
         size_type nbc3 = pnode->nbc3;
         size_type nbcl = pnode->children.size()/(nbc1*nbc2*nbc3);
         if (nbc1 > 1) str << "[";



reply via email to

[Prev in Thread] Current Thread [Next in Thread]