Commit dfb6226e authored by Uliana Alekseeva

bugfix in hsmt_nonsph.F90 for GPU version

parent a6795b74
......@@ -25,6 +25,8 @@ CONTAINS
USE m_types
USE m_ylm
USE m_apws
USE cudafor
USE nvtx
IMPLICIT NONE
TYPE(t_sym),INTENT(IN) :: sym
TYPE(t_cell),INTENT(IN) :: cell
......@@ -38,7 +40,7 @@ CONTAINS
INTEGER,INTENT(OUT) :: ab_size
! ..
! .. Array Arguments ..
REAL, INTENT(IN) :: fj(:,:,:),gj(:,:,:)
REAL, DEVICE, INTENT(IN) :: fj(:,:,:),gj(:,:,:)
COMPLEX,DEVICE, INTENT (OUT) :: ab(:,:)
!Optional arguments if abc coef for LOs are needed
COMPLEX, INTENT(INOUT),OPTIONAL:: abclo(:,-atoms%llod:,:,:)
......@@ -49,29 +51,28 @@ CONTAINS
COMPLEX,ALLOCATABLE :: ylm(:,:)
COMPLEX,ALLOCATABLE :: c_ph(:,:)
REAL, ALLOCATABLE :: gkrot(:,:)
LOGICAL :: l_apw
COMPLEX:: term
REAL, ALLOCATABLE,DEVICE :: fj_dev(:,:,:), gj_dev(:,:,:)
COMPLEX,ALLOCATABLE,DEVICE :: c_ph_dev(:,:)
COMPLEX,ALLOCATABLE,DEVICE :: ylm_dev(:,:)
REAL, ALLOCATABLE,DEVICE :: gkrot_dev(:,:)
INTEGER :: istat
ALLOCATE(fj_dev(MAXVAL(lapw%nv),atoms%lmaxd+1,MERGE(2,1,noco%l_noco)))
ALLOCATE(gj_dev(MAXVAL(lapw%nv),atoms%lmaxd+1,MERGE(2,1,noco%l_noco)))
call nvtxStartRange("hsmt_ab",2)
lmax=MERGE(atoms%lnonsph(n),atoms%lmax(n),l_nonsph)
ALLOCATE(c_ph_dev(lapw%nv(1),MERGE(2,1,noco%l_ss)))
ALLOCATE(ylm_dev(lapw%nv(1),(atoms%lmaxd+1)**2))
fj_dev(:,:,:)= fj(:,:,:)
gj_dev(:,:,:)= gj(:,:,:)
ALLOCATE(ylm_dev((lmax+1)**2,lapw%nv(1)))
ALLOCATE(gkrot_dev(3,lapw%nv(1)))
ALLOCATE(ylm(lapw%nv(1),(atoms%lmaxd+1)**2))
ALLOCATE(ylm((lmax+1)**2,lapw%nv(1)))
ALLOCATE(c_ph(lapw%nv(1),MERGE(2,1,noco%l_ss)))
ALLOCATE(gkrot(3,lapw%nv(1)))
lmax=MERGE(atoms%lnonsph(n),atoms%lmax(n),l_nonsph)
ab_size=lmax*(lmax+2)+1
l_apw=ALL(gj==0.0)
ab=0.0
np = sym%invtab(atoms%ngopr(na))
......@@ -92,50 +93,64 @@ CONTAINS
END DO
END IF
!--> generate spherical harmonics
gkrot_dev = gkrot
!--> synthesize the complex conjugates of a and b
!!$cuf kernel do <<<*,256>>>
!DO k = 1,lapw%nv(1)
! !--> generate spherical harmonics
! CALL ylm4_dev(lmax,gkrot_dev(:,k),ylm_dev(:,k))
!ENDDO
DO k = 1,lapw%nv(1)
vmult(:) = gkrot(:,k)
CALL ylm4(lmax,vmult,ylm(k,:))
call ylm4(lmax,gkrot(:,k),ylm(:,k))
ENDDO
ylm_dev=ylm
ylm_dev = ylm
!--> synthesize the complex conjugates of a and b
call nvtxStartRange("hsmt_cuf",5)
!$cuf kernel do <<<*,256>>>
DO k = 1,lapw%nv(1)
!--> generate spherical harmonics
!CALL ylm4_dev(lmax,gkrot_dev(:,k),ylm_dev(:,k))
DO l = 0,lmax
ll1 = l* (l+1)
DO m = -l,l
ab(k,ll1+m+1) = CONJG(fj_dev(k,l+1,iintsp)*c_ph_dev(k,iintsp)*ylm_dev(k,ll1+m+1))
ab(k,ll1+m+1+ab_size) = CONJG(gj_dev(k,l+1,iintsp)*c_ph_dev(k,iintsp)*ylm_dev(k,ll1+m+1))
ab(k,ll1+m+1) = CONJG(fj(k,l+1,iintsp)*c_ph_dev(k,iintsp)*ylm_dev(ll1+m+1,k))
ab(k,ll1+m+1+ab_size) = CONJG(gj(k,l+1,iintsp)*c_ph_dev(k,iintsp)*ylm_dev(ll1+m+1,k))
END DO
END DO
ENDDO !k-loop
istat = cudaDeviceSynchronize()
call nvtxEndRange
IF (PRESENT(abclo)) THEN
DO k = 1,lapw%nv(1)
!determine also the abc coeffs for LOs
invsfct=MERGE(1,2,atoms%invsat(na).EQ.0)
term = fpi_const/SQRT(cell%omtil)* ((atoms%rmt(n)**2)/2)*c_ph(k,iintsp)
DO lo = 1,atoms%nlo(n)
l = atoms%llo(lo,n)
DO nkvec=1,invsfct*(2*l+1)
IF (lapw%kvec(nkvec,lo,na)==k) THEN !This k-vector is used in LO
ll1 = l*(l+1) + 1
DO m = -l,l
lm = ll1 + m
abclo(1,m,nkvec,lo) = term*ylm(k,lm)*alo1(lo)
abclo(2,m,nkvec,lo) = term*ylm(k,lm)*blo1(lo)
abclo(3,m,nkvec,lo) = term*ylm(k,lm)*clo1(lo)
END DO
END IF
ENDDO
ENDDO
ENDDO
print*, "Ooooops, TODO in hsmt_ab"
!DO k = 1,lapw%nv(1)
! !determine also the abc coeffs for LOs
! invsfct=MERGE(1,2,atoms%invsat(na).EQ.0)
! term = fpi_const/SQRT(cell%omtil)* ((atoms%rmt(n)**2)/2)*c_ph(k,iintsp)
! DO lo = 1,atoms%nlo(n)
! l = atoms%llo(lo,n)
! DO nkvec=1,invsfct*(2*l+1)
! IF (lapw%kvec(nkvec,lo,na)==k) THEN !This k-vector is used in LO
! ll1 = l*(l+1) + 1
! DO m = -l,l
! lm = ll1 + m
! abclo(1,m,nkvec,lo) = term*ylm(k,lm)*alo1(lo)
! abclo(2,m,nkvec,lo) = term*ylm(k,lm)*blo1(lo)
! abclo(3,m,nkvec,lo) = term*ylm(k,lm)*clo1(lo)
! END DO
! END IF
! ENDDO
! ENDDO
!ENDDO
ENDIF
IF (.NOT.l_apw) ab_size=ab_size*2
ab_size=ab_size*2
call nvtxEndRange
END SUBROUTINE hsmt_ab_gpu
#endif
......
......@@ -70,29 +70,32 @@ CONTAINS
COMPLEX,ALLOCATABLE:: ab(:,:),ab1(:,:),ab2(:,:)
real :: rchi
#ifdef _CUDA
COMPLEX,ALLOCATABLE,DEVICE :: c_dev(:,:), ab1_dev(:,:), ab_dev(:,:)
COMPLEX,ALLOCATABLE,DEVICE :: c_dev(:,:), ab1_dev(:,:), ab_dev(:,:), ab2_dev(:,:)
COMPLEX,ALLOCATABLE,DEVICE :: h_loc_dev(:,:)
!REAL, ALLOCATABLE,DEVICE :: fj_dev(:,:,:), gj_dev(:,:,:)
REAL, ALLOCATABLE,DEVICE :: fj_dev(:,:,:), gj_dev(:,:,:)
integer :: i, j, istat
call nvtxStartRange("hsmt_nonsph",1)
print*, "running CUDA version"
#endif
print *, "nonsph start"
ALLOCATE(ab(MAXVAL(lapw%nv),2*atoms%lmaxd*(atoms%lmaxd+2)+2),ab1(lapw%nv(jintsp),2*atoms%lmaxd*(atoms%lmaxd+2)+2))
#ifdef _CUDA
ALLOCATE(h_loc_dev(size(td%h_loc,1),size(td%h_loc,2)))
ALLOCATE(ab1_dev(size(ab1,1),size(ab1,2)))
ALLOCATE(ab_dev(size(ab,1),size(ab,2)))
h_loc_dev(1:,1:) = CONJG(td%h_loc(0:,0:,n,isp)) !WORKAROUND, var_dev=CONJG(var_dev) does not work (pgi18.4)
!ALLOCATE(fj_dev(MAXVAL(lapw%nv),atoms%lmaxd+1,MERGE(2,1,noco%l_noco)))
!ALLOCATE(gj_dev(MAXVAL(lapw%nv),atoms%lmaxd+1,MERGE(2,1,noco%l_noco)))
!fj_dev(1:,1:,1:)= fj(1:,0:,1:)
!gj_dev(1:,1:,1:)= gj(1:,0:,1:)
h_loc_dev(1:,1:) = CONJG(td%h_loc(0:,0:,n,isp)) !WORKAROUND, var_dev=CONJG(var_dev) does not work
ALLOCATE(fj_dev(MAXVAL(lapw%nv),atoms%lmaxd+1,MERGE(2,1,noco%l_noco)))
ALLOCATE(gj_dev(MAXVAL(lapw%nv),atoms%lmaxd+1,MERGE(2,1,noco%l_noco)))
fj_dev(1:,1:,1:)= fj(1:,0:,1:)
gj_dev(1:,1:,1:)= gj(1:,0:,1:)
!note that basically all matrices in the GPU version are conjugates of their
!cpu counterparts
#endif
IF (iintsp.NE.jintsp) ALLOCATE(ab2(lapw%nv(iintsp),2*atoms%lmaxd*(atoms%lmaxd+2)+2))
#ifdef _CUDA
IF (iintsp.NE.jintsp) ALLOCATE(ab2_dev(lapw%nv(iintsp),2*atoms%lmaxd*(atoms%lmaxd+2)+2))
#endif
IF (hmat%l_real) THEN
IF (ANY(SHAPE(hmat%data_c)/=SHAPE(hmat%data_r))) THEN
......@@ -110,16 +113,15 @@ CONTAINS
na = SUM(atoms%neq(:n-1))+nn
IF ((atoms%invsat(na)==0) .OR. (atoms%invsat(na)==1)) THEN
rchi=MERGE(REAL(chi),REAL(chi)*2,(atoms%invsat(na)==0))
#ifdef _CUDA
CALL hsmt_ab(sym,atoms,noco,isp,jintsp,n,na,cell,lapw,fj,gj,ab_dev,ab_size,.TRUE.)
! istat = cudaDeviceSynchronize()
CALL hsmt_ab(sym,atoms,noco,isp,jintsp,n,na,cell,lapw,fj_dev,gj_dev,ab_dev,ab_size,.TRUE.)
! istat = cudaDeviceSynchronize()
#else
CALL hsmt_ab(sym,atoms,noco,isp,jintsp,n,na,cell,lapw,fj,gj,ab,ab_size,.TRUE.)
#endif
!Calculate Hamiltonian
#ifdef _CUDA
!ab_dev = CONJG(ab)
CALL zgemm("N","N",lapw%nv(jintsp),ab_size,ab_size,CMPLX(1.0,0.0),ab_dev,SIZE(ab_dev,1),h_loc_dev,SIZE(h_loc_dev,1),CMPLX(0.,0.),ab1_dev,SIZE(ab1_dev,1))
#else
CALL zgemm("N","N",lapw%nv(jintsp),ab_size,ab_size,CMPLX(1.0,0.0),ab,SIZE(ab,1),td%h_loc(0:,0:,n,isp),SIZE(td%h_loc,1),CMPLX(0.,0.),ab1,SIZE(ab1,1))
......@@ -127,7 +129,8 @@ CONTAINS
!ab1=MATMUL(ab(:lapw%nv(iintsp),:ab_size),td%h_loc(:ab_size,:ab_size,n,isp))
IF (iintsp==jintsp) THEN
#ifdef _CUDA
call nvtxStartRange("zherk",3)
call nvtxStartRange("zherk",3)
!ab1_dev=CONJG(ab1)
CALL ZHERK("U","N",lapw%nv(iintsp),ab_size,Rchi,ab1_dev,SIZE(ab1_dev,1),1.0,c_dev,SIZE(c_dev,1))
istat = cudaDeviceSynchronize()
call nvtxEndRange()
......@@ -136,10 +139,24 @@ CONTAINS
#endif
ELSE !here the l_ss off-diagonal part starts
!Second set of ab is needed
#ifdef _CUDA
CALL hsmt_ab(sym,atoms,noco,isp,iintsp,n,na,cell,lapw,fj_dev,gj_dev,ab_dev,ab_size,.TRUE.)
#else
CALL hsmt_ab(sym,atoms,noco,isp,iintsp,n,na,cell,lapw,fj,gj,ab,ab_size,.TRUE.)
CALL zgemm("N","N",lapw%nv(iintsp),ab_size,ab_size,CMPLX(1.0,0.0),ab,SIZE(ab,1),td%h_loc(:,:,n,isp),SIZE(td%h_loc,1),CMPLX(0.,0.),ab2,SIZE(ab2,1))
#endif
#ifdef _CUDA
CALL zgemm("N","N",lapw%nv(iintsp),ab_size,ab_size,CMPLX(1.0,0.0),ab_dev,SIZE(ab_dev,1),h_loc_dev,SIZE(td%h_loc,1),CMPLX(0.,0.),ab2_dev,SIZE(ab2_dev,1))
#else
CALL zgemm("N","N",lapw%nv(iintsp),ab_size,ab_size,CMPLX(1.0,0.0),ab,SIZE(ab,1),td%h_loc(0:,0:,n,isp),SIZE(td%h_loc,1),CMPLX(0.,0.),ab2,SIZE(ab2,1))
#endif
!Multiply for Hamiltonian
#ifdef _CUDA
ab1 = ab1_dev
ab1_dev=CONJG(ab1)
CALL zgemm("N","T",lapw%nv(iintsp),lapw%nv(jintsp),ab_size,chi,ab2_dev,SIZE(ab2_dev,1),ab1_dev,SIZE(ab1_dev,1),CMPLX(1.0,0.0),c_dev,SIZE(c_dev,1))
#else
CALL zgemm("N","T",lapw%nv(iintsp),lapw%nv(jintsp),ab_size,chi,conjg(ab2),SIZE(ab2,1),ab1,SIZE(ab1,1),CMPLX(1.0,0.0),hmat%data_c,SIZE(hmat%data_c,1))
#endif
ENDIF
ENDIF
END DO
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment