Commit f2419a9b authored by Uliana Alekseeva

hsmt_nonsph: change the explicit data copy into the implicit one

parent d883d3c5
...@@ -29,7 +29,7 @@ CONTAINS ...@@ -29,7 +29,7 @@ CONTAINS
INTEGER, INTENT (IN) :: n,isp,iintsp,jintsp INTEGER, INTENT (IN) :: n,isp,iintsp,jintsp
COMPLEX,INTENT(IN) :: chi COMPLEX,INTENT(IN) :: chi
! .. Array Arguments .. ! .. Array Arguments ..
REAL,INTENT(IN) :: fj(:,0:,:),gj(:,0:,:) REAL,INTENT(IN) :: fj(:,0:,:),gj(:,0:,:)
CLASS(t_mat),INTENT(INOUT) ::hmat CLASS(t_mat),INTENT(INOUT) ::hmat
#if defined CPP_GPU #if defined CPP_GPU
REAL, ALLOCATABLE,DEVICE :: fj_dev(:,:,:), gj_dev(:,:,:) REAL, ALLOCATABLE,DEVICE :: fj_dev(:,:,:), gj_dev(:,:,:)
...@@ -90,7 +90,6 @@ CONTAINS ...@@ -90,7 +90,6 @@ CONTAINS
INTEGER:: nn,na,ab_size,l,ll,m INTEGER:: nn,na,ab_size,l,ll,m
real :: rchi real :: rchi
COMPLEX,ALLOCATABLE,DEVICE :: ab1_dev(:,:), ab_dev(:,:), ab2_dev(:,:) COMPLEX,ALLOCATABLE,DEVICE :: ab1_dev(:,:), ab_dev(:,:), ab2_dev(:,:)
COMPLEX,ALLOCATABLE,DEVICE :: c_dev(:,:)
integer :: i, j, istat integer :: i, j, istat
call nvtxStartRange("hsmt_nonsph",1) call nvtxStartRange("hsmt_nonsph",1)
...@@ -105,8 +104,6 @@ CONTAINS ...@@ -105,8 +104,6 @@ CONTAINS
ENDIF ENDIF
hmat%data_c=0.0 hmat%data_c=0.0
ENDIF ENDIF
ALLOCATE(c_dev(SIZE(hmat%data_c,1),SIZE(hmat%data_c,2)))
c_dev = hmat%data_c
DO nn = 1,atoms%neq(n) DO nn = 1,atoms%neq(n)
na = SUM(atoms%neq(:n-1))+nn na = SUM(atoms%neq(:n-1))+nn
...@@ -121,7 +118,7 @@ CONTAINS ...@@ -121,7 +118,7 @@ CONTAINS
!ab1=MATMUL(ab(:lapw%nv(iintsp),:ab_size),td%h_loc(:ab_size,:ab_size,n,isp)) !ab1=MATMUL(ab(:lapw%nv(iintsp),:ab_size),td%h_loc(:ab_size,:ab_size,n,isp))
IF (iintsp==jintsp) THEN IF (iintsp==jintsp) THEN
call nvtxStartRange("zherk",3) call nvtxStartRange("zherk",3)
CALL ZHERK("U","N",lapw%nv(iintsp),ab_size,Rchi,ab1_dev,SIZE(ab1_dev,1),1.0,c_dev,SIZE(c_dev,1)) CALL ZHERK("U","N",lapw%nv(iintsp),ab_size,Rchi,ab1_dev,SIZE(ab1_dev,1),1.0,hmat%data_c,SIZE(hmat%data_c,1))
istat = cudaDeviceSynchronize() istat = cudaDeviceSynchronize()
call nvtxEndRange() call nvtxEndRange()
ELSE !here the l_ss off-diagonal part starts ELSE !here the l_ss off-diagonal part starts
...@@ -138,13 +135,11 @@ CONTAINS ...@@ -138,13 +135,11 @@ CONTAINS
enddo enddo
enddo enddo
CALL zgemm("N","T",lapw%nv(iintsp),lapw%nv(jintsp),ab_size,chi,ab2_dev,SIZE(ab2_dev,1),& CALL zgemm("N","T",lapw%nv(iintsp),lapw%nv(jintsp),ab_size,chi,ab2_dev,SIZE(ab2_dev,1),&
ab1_dev,SIZE(ab1_dev,1),CMPLX(1.0,0.0),c_dev,SIZE(c_dev,1)) ab1_dev,SIZE(ab1_dev,1),CMPLX(1.0,0.0),hmat%data_c,SIZE(hmat%data_c,1))
ENDIF ENDIF
ENDIF ENDIF
END DO END DO
hmat%data_c = c_dev
IF (hmat%l_real) THEN IF (hmat%l_real) THEN
hmat%data_r=hmat%data_r+REAL(hmat%data_c) hmat%data_r=hmat%data_r+REAL(hmat%data_c)
ENDIF ENDIF
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment