Actual source code: aijkokkosimpl.hpp

  1: #if !defined(SEQAIJKOKKOSIMPL_HPP)
  2: #define SEQAIJKOKKOSIMPL_HPP

  4: #include <petscaijdevice.h>
  5: #include <petsc/private/vecimpl_kokkos.hpp>
  6: #include <KokkosSparse_CrsMatrix.hpp>
  7: #include <KokkosSparse_spiluk.hpp>

  /*
 * Workaround name for struct _n_SplitCSRMat.
 * A Kokkos::View whose element type is written with the elaborated specifier, e.g.
 * Kokkos::View<struct _n_SplitCSRMat,DefaultMemorySpace>, is not handled correctly,
 * so such Views should use this plain alias as their element type instead.
 */
using SplitCSRMat = struct _n_SplitCSRMat;

 15: using MatRowOffsetType    = PetscInt;
 16: using MatColumnIndexType  = PetscInt;
 17: using MatValueType        = PetscScalar;

 19: template<class MemorySpace> using KokkosCsrMatrixType   = typename KokkosSparse::CrsMatrix<MatValueType,MatColumnIndexType,MemorySpace,void/* MemoryTraits */,MatRowOffsetType>;
 20: template<class MemorySpace> using KokkosCsrGraphType    = typename KokkosCsrMatrixType<MemorySpace>::staticcrsgraph_type;

 22: using KokkosCsrGraph                      = KokkosCsrGraphType<DefaultMemorySpace>;
 23: using KokkosCsrMatrix                     = KokkosCsrMatrixType<DefaultMemorySpace>;

 25: using KokkosCsrGraphHost                  = KokkosCsrGraphType<DefaultMemorySpace>::HostMirror;

 27: using ConstMatColumnIndexKokkosView       = KokkosCsrGraph::entries_type;
 28: using ConstMatRowOffsetKokkosView         = KokkosCsrGraph::row_map_type;
 29: using ConstMatValueKokkosView             = KokkosCsrMatrix::values_type;

 31: using MatColumnIndexKokkosView            = KokkosCsrGraph::entries_type::non_const_type;
 32: using MatRowOffsetKokkosView              = KokkosCsrGraph::row_map_type::non_const_type;
 33: using MatValueKokkosView                  = KokkosCsrMatrix::values_type::non_const_type;

 35: using MatColumnIndexKokkosViewHost        = MatColumnIndexKokkosView::HostMirror;
 36: using MatRowOffsetKokkosViewHost          = MatRowOffsetKokkosView::HostMirror;
 37: using MatValueKokkosViewHost              = MatValueKokkosView::HostMirror;

 39: using MatValueKokkosDualView              = Kokkos::DualView<MatValueType*>;

 41: using KernelHandle                        = KokkosKernels::Experimental::KokkosKernelsHandle<MatRowOffsetType,MatColumnIndexType,MatValueType,DefaultExecutionSpace,DefaultMemorySpace,DefaultMemorySpace>;

 43: struct Mat_SeqAIJKokkosTriFactors {
 44:   MatRowOffsetKokkosView         iL_d,iU_d,iLt_d,iUt_d; /* rowmap for L, U, L^t, U^t of A=LU */
 45:   MatColumnIndexKokkosView       jL_d,jU_d,jLt_d,jUt_d; /* column ids */
 46:   MatValueKokkosView             aL_d,aU_d,aLt_d,aUt_d; /* matrix values */
 47:   KernelHandle                   kh,khL,khU,khLt,khUt;  /* Kernel handles for A, L, U, L^t, U^t */
 48:   PetscBool                      transpose_updated;     /* Are L^T, U^T updated wrt L, U*/
 49:   PetscBool                      sptrsv_symbolic_completed; /* Have we completed the symbolic solve for L and U */
 50:   PetscScalarKokkosView          workVector;

 52:   Mat_SeqAIJKokkosTriFactors(PetscInt n)
 53:     : transpose_updated(PETSC_FALSE),sptrsv_symbolic_completed(PETSC_FALSE),workVector("workVector",n) {}

 55:   ~Mat_SeqAIJKokkosTriFactors() {Destroy();}

 57:   void Destroy() {
 58:     kh.destroy_spiluk_handle();
 59:     khL.destroy_sptrsv_handle();
 60:     khU.destroy_sptrsv_handle();
 61:     khLt.destroy_sptrsv_handle();
 62:     khUt.destroy_sptrsv_handle();
 63:     transpose_updated = sptrsv_symbolic_completed = PETSC_FALSE;
 64:   }
 65: };

 67: struct Mat_SeqAIJKokkos {
 68:   MatRowOffsetKokkosViewHost     i_h;
 69:   MatRowOffsetKokkosView         i_d;

 71:   MatColumnIndexKokkosViewHost   j_h;
 72:   MatColumnIndexKokkosView       j_d;

 74:   MatValueKokkosViewHost         a_h;
 75:   MatValueKokkosView             a_d;

 77:   MatValueKokkosDualView         a_dual;

 79:   KokkosCsrGraphHost             csrgraph_h;
 80:   KokkosCsrGraph                 csrgraph_d;

 82:   KokkosCsrMatrix                csrmat; /* The CSR matrix */
 83:   PetscObjectState               nonzerostate; /* State of the nonzero pattern (graph) on device */

 85:   Mat                            At,Ah; /* Transpose and Hermitian of the matrix in MATAIJKOKKOS type (built on demand) */
 86:   PetscBool                      transpose_updated,hermitian_updated; /* Are At, Ah updated wrt the matrix? */

 88:   Kokkos::View<PetscInt*>        *i_uncompressed_d;
 89:   Kokkos::View<PetscInt*>        *colmap_d; // ugh, this is a parallel construct
 90:   Kokkos::View<SplitCSRMat,DefaultMemorySpace> device_mat_d;
 91:   Kokkos::View<PetscInt*>        *diag_d; // factorizations

 93:    /* Construct a nrows by ncols matrix of nnz nonzeros with (i,j,a) for the CSR */
 94:   Mat_SeqAIJKokkos(MatColumnIndexType nrows,MatColumnIndexType ncols,MatRowOffsetType nnz,MatRowOffsetType *i,MatColumnIndexType *j,MatValueType *a)
 95:    : i_h(i,nrows+1),j_h(j,nnz),a_h(a,nnz),At(NULL),Ah(NULL),transpose_updated(PETSC_FALSE),hermitian_updated(PETSC_FALSE),
 96:      i_uncompressed_d(NULL),colmap_d(NULL),device_mat_d(NULL),diag_d(NULL)
 97:   {
 98:      i_d        = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),i_h);
 99:      j_d        = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),j_h);
100:      a_d        = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(),a_h);
101:      csrgraph_d = KokkosCsrGraph(j_d,i_d);
102:      csrgraph_h = KokkosCsrGraphHost(j_h,i_h);
103:      a_dual     = MatValueKokkosDualView(a_d,a_h);
104:      csrmat     = KokkosCsrMatrix("csrmat",ncols,a_d,csrgraph_d);
105:   }

107:   ~Mat_SeqAIJKokkos()
108:   {
109:     DestroyMatTranspose();
110:   }

112:   PetscErrorCode DestroyMatTranspose(void)
113:   {
116:     MatDestroy(&At);
117:     MatDestroy(&Ah);
118:     return(0);
119:   }
120: };

122: #endif