! hsmt_nonsph.F90
MODULE m_hsmt_nonsph
#define CPP_BLOCKSIZE 64
  !      USE m_juDFT
  !$     USE omp_lib

  !TODO:
  !  Check what can be done in l_noco=.true. case in terms of use of zgemm or aa_block
  !  Check what happens in case of CPP_INVERSION -> real matrix a

  IMPLICIT NONE
CONTAINS
  !> Accumulate the non-spherical muffin-tin contribution of all atom types
  !> into the packed lower-triangular matrix aa_r (l_real) or aa_c.
  !> NOTE(review): the name suggests this is the Hamiltonian part -- confirm.
  !>
  !> For every atom type n with lnonsph(n) >= 0 and every equivalent atom
  !> whose invsat is 0 or 1 (invsat==1 atoms are weighted by a factor 2):
  !>   a(k,lm) = fj(k,l,n,spin) * exp(-i*2pi*th) * ylm(lm)
  !>   b(k,lm) = gj(k,l,n,spin) * exp(-i*2pi*th) * ylm(lm)
  !> with th = (k + (iintsp-1.5)*qss) . taual(:,na). ylm is evaluated at gk,
  !> or at vk rotated by the symmetry operation belonging to ngopr(na).
  !> The t-matrices from tlmplm give
  !>   ax = a*utu + b*utd ,   bx = a*dtu + b*dtd
  !> and, for each row ki and columns kj <= ki, the element
  !>   sum_lm a(ki,lm)*conjg(ax(kj,lm)) + b(ki,lm)*conjg(bx(kj,lm))
  !> is added to the packed matrix.
  !>
  !> Rows are distributed round-robin: this rank handles the global row
  !> indices kii = n_rank, n_rank+n_size, ... .
  !>
  !> Non-collinear handling:
  !>  * without spin spiral, the rows of the first spin block go to aahlp. That
  !>    buffer is merged into aa_c by hsmt_hlptomat after each atom type. The
  !>    remaining rows are added directly to aa_c with the factors chi22/chi21.
  !>  * with spin spiral (l_ss), each spin block is scaled by chi11, chi22 or
  !>    chi21 and added directly to aa_c.
  !>
  !> Side effect: lapw%nv_tot is overwritten.
  SUBROUTINE hsmt_nonsph(DIMENSION,atoms,sym,SUB_COMM, n_size,n_rank,input,isp,nintsp,&
       hlpmsize,noco,l_socfirst, lapw, cell,tlmplm, fj,gj,gk,vk,oneD,l_real,aa_r,aa_c)

#include"cpp_double.h"
    USE m_constants, ONLY : tpi_const
    USE m_ylm
    USE m_hsmt_spinor
    USE m_hsmt_hlptomat
    USE m_types
    IMPLICIT NONE
    TYPE(t_dimension),INTENT(IN):: DIMENSION
    TYPE(t_oneD),INTENT(IN)     :: oneD
    TYPE(t_input),INTENT(IN)    :: input
    TYPE(t_noco),INTENT(IN)     :: noco
    TYPE(t_sym),INTENT(IN)      :: sym
    TYPE(t_cell),INTENT(IN)     :: cell
    TYPE(t_atoms),INTENT(IN)    :: atoms
    TYPE(t_lapw),INTENT(INOUT)  :: lapw !lapw%nv_tot is updated

    !     ..
    !     .. Scalar Arguments ..
    ! nintsp: number of interstitial spins looped over; isp: current spin index
    INTEGER, INTENT (IN) :: nintsp,isp
    ! SUB_COMM is only forwarded to hsmt_hlptomat; n_size/n_rank drive the
    ! round-robin row distribution
    INTEGER, INTENT (IN) :: SUB_COMM,n_size,n_rank
    ! length of the non-collinear helper buffers aahlp/aa_tmphlp
    INTEGER, INTENT (IN) :: hlpmsize
    ! if set (or noco%l_constr), fj/gj are taken for spin isp instead of iintsp
    LOGICAL, INTENT (IN) :: l_socfirst
    !     ..
    !     .. Array Arguments ..
    TYPE(t_tlmplm),INTENT(IN)::tlmplm
    ! fj/gj are indexed (k, l, atom type, spin)
    REAL, INTENT(IN)     :: fj(:,0:,:,:),gj(:,0:,:,:)
    ! gk/vk are indexed (k, 1:3, interstitial spin)
    REAL,INTENT(IN)      :: gk(:,:,:),vk(:,:,:)
    ! l_real selects which packed matrix (aa_r or aa_c) the collinear path updates
    LOGICAL, INTENT(IN)     :: l_real
    REAL,    INTENT (INOUT) :: aa_r(:)!(matsize)
    COMPLEX, INTENT (INOUT) :: aa_c(:)

    COMPLEX,PARAMETER :: one=CMPLX(1.0,0.0),zero=CMPLX(0.0,0.0)

    !     ..
    !     .. Local Scalars ..
    INTEGER :: ii,ij,im,in,k,ki,l,ll1,lm,lmp,lp,jd,m
    INTEGER :: mp,n,na,nn,np,kjmax,iintsp,jintsp
    INTEGER :: nc,kii,spin2,ab_dim,lnonsphd,bsize,bsize2,kb
    REAL    :: th,invsfct
    COMPLEX :: term,chi11,chi21,chi22,chihlp

    !     ..
    !     .. Local Arrays ..
    COMPLEX,ALLOCATABLE :: aa_block(:,:)
    COMPLEX,ALLOCATABLE :: dtd(:,:),dtu(:,:),utd(:,:),utu(:,:)
    REAL   :: bmrot(3,3),gkrot(DIMENSION%nvd,3),vmult(3),v(3)
    COMPLEX:: ylm( (atoms%lmaxd+1)**2 ),chi(2,2)
    COMPLEX, ALLOCATABLE :: a(:,:,:),b(:,:,:),ax(:,:),bx(:,:)
    COMPLEX, ALLOCATABLE :: c_ph(:,:)
    COMPLEX,ALLOCATABLE :: aahlp(:),aa_tmphlp(:)

    ! t-matrix work arrays, sized for the largest lnonsph of all types
    lnonsphd=MAXVAL(atoms%lnonsph)*(MAXVAL(atoms%lnonsph)+2)
    ALLOCATE(dtd(0:lnonsphd,0:lnonsphd),utd(0:lnonsphd,0:lnonsphd),&
         dtu(0:lnonsphd,0:lnonsphd),utu(0:lnonsphd,0:lnonsphd))

    IF ( noco%l_noco .AND. (.NOT. noco%l_ss) ) ALLOCATE ( aahlp(hlpmsize),aa_tmphlp(hlpmsize) )

    ! Row-block buffer: at most CPP_BLOCKSIZE rows are formed per zgemm call
    ALLOCATE(aa_block(CPP_BLOCKSIZE,MAXVAL(lapw%nv)))

    ! a/b/c_ph need a second interstitial-spin slot only for spin spirals
    ab_dim=1
    IF (noco%l_ss) ab_dim=2
    ALLOCATE(a(DIMENSION%nvd,0:DIMENSION%lmd,ab_dim),b(DIMENSION%nvd,0:DIMENSION%lmd,ab_dim))
    ALLOCATE(ax(DIMENSION%nvd,0:DIMENSION%lmd),bx(DIMENSION%nvd,0:DIMENSION%lmd))
    ALLOCATE(c_ph(DIMENSION%nvd,ab_dim))

    ntyploop: DO n=1,atoms%ntype
       IF (noco%l_noco) THEN
          IF (.NOT.noco%l_ss) aahlp=CMPLX(0.0,0.0)
          IF (.NOT.noco%l_ss) aa_tmphlp=CMPLX(0.0,0.0)
          CALL hsmt_spinor(isp,n, noco,input, chi, chi11, chi21, chi22)
       ENDIF
       ! No non-spherical contribution for this type. The test does not depend
       ! on the equivalent atom, so it is done once instead of per atom.
       IF (atoms%lnonsph(n)<0) CYCLE ntyploop
       DO nn = 1,atoms%neq(n)
          na = SUM(atoms%neq(:n-1))+nn
          ! invsat==2 atoms are skipped; invsat==1 atoms count twice
          IF ((atoms%invsat(na)==0) .OR. (atoms%invsat(na)==1)) THEN
             IF (atoms%invsat(na)==0) invsfct = 1
             IF (atoms%invsat(na)==1) invsfct = 2
             a=0.0
             b=0.0
             np = sym%invtab(atoms%ngopr(na))
             IF (oneD%odi%d1) np = oneD%ods%ngopr(na)
             !--->       loop over interstitial spins
             DO iintsp = 1,nintsp
                IF (noco%l_constr.OR.l_socfirst) THEN
                   spin2=isp
                ELSE
                   spin2=iintsp
                ENDIF
                !--->          set up phase factors exp(-i*2pi*(k+q)*tau)
                DO k = 1,lapw%nv(iintsp)
                   th= DOT_PRODUCT((/lapw%k1(k,iintsp),lapw%k2(k,iintsp),lapw%k3(k,iintsp)/)&
                        +(iintsp-1.5)*noco%qss,atoms%taual(:,na))
                   c_ph(k,iintsp) = CMPLX(COS(tpi_const*th),-SIN(tpi_const*th))
                END DO

                IF (np==1) THEN
                   gkrot( 1:lapw%nv(iintsp),:) = gk( 1:lapw%nv(iintsp),:,iintsp)
                ELSE
                   IF (oneD%odi%d1) THEN
                      bmrot=MATMUL(oneD%ods%mrot(:,:,np),cell%bmat)
                   ELSE
                      bmrot=MATMUL(1.*sym%mrot(:,:,np),cell%bmat)
                   END IF
                   DO k = 1,lapw%nv(iintsp)
                      !-->  apply the rotation that brings this atom into the
                      !-->  representative (this is the definition of ngopr(na)
                      !-->  and transform to cartesian coordinates
                      v(:) = vk(k,:,iintsp)
                      gkrot(k,:) = MATMUL(TRANSPOSE(bmrot),v)
                   END DO
                END IF
                DO k = 1,lapw%nv(iintsp)
                   !-->    generate spherical harmonics
                   vmult(:) =  gkrot(k,:)
                   CALL ylm4(atoms%lnonsph(n),vmult,ylm)
                   !-->  synthesize the complex conjugates of a and b
                   DO l = 0,atoms%lnonsph(n)
                      ll1 = l* (l+1)
                      DO m = -l,l
                         term = c_ph(k,iintsp)*ylm(ll1+m+1)
                         a(k,ll1+m,iintsp) = fj(k,l,n,spin2)*term
                         b(k,ll1+m,iintsp) = gj(k,l,n,spin2)*term
                      END DO
                   END DO
                ENDDO !k-loop
                !--->       end loop over interstitial spin
             ENDDO
             !--->       loops over the interstitial spin
             DO iintsp = 1,nintsp

                DO jintsp = 1,iintsp

                   jd = 1 ; IF (noco%l_noco) jd = isp
                   !--->  build the (lm,lm') t-matrices, scaled by invsfct.
                   !--->  ind<0 means the transposed element is stored; it is
                   !--->  used without conjugation and with tud/tdu swapped.
                   utu=0.0;utd=0.0;dtu=0.0;dtd=0.0
                   DO lmp=0,atoms%lnonsph(n)*(atoms%lnonsph(n)+2)
                      lp=FLOOR(SQRT(1.0*lmp))
                      mp=lmp-lp*(lp+1)
                      IF (lp>atoms%lnonsph(n).OR.ABS(mp)>lp) STOP "BUG"
                      !--->             loop over l,m
                      DO l = 0,atoms%lnonsph(n)
                         DO m = -l,l
                            lm = l* (l+1) + m
                            in = tlmplm%ind(lmp,lm,n,jd)
                            IF (in/=-9999) THEN
                               IF (in>=0) THEN
                                  utu(lm,lmp) =CONJG(tlmplm%tuu(in,n,jd))*invsfct
                                  dtu(lm,lmp) =CONJG(tlmplm%tdu(in,n,jd))*invsfct
                                  utd(lm,lmp) =CONJG(tlmplm%tud(in,n,jd))*invsfct
                                  dtd(lm,lmp) =CONJG(tlmplm%tdd(in,n,jd))*invsfct
                               ELSE
                                  im = -in
                                  utu(lm,lmp) =tlmplm%tuu(im,n,jd)*invsfct
                                  dtu(lm,lmp) =tlmplm%tud(im,n,jd)*invsfct
                                  utd(lm,lmp) =tlmplm%tdu(im,n,jd)*invsfct
                                  dtd(lm,lmp) =tlmplm%tdd(im,n,jd)*invsfct
                               END IF
                            END IF
                         END DO
                      END DO
                   ENDDO
                   lmp=atoms%lnonsph(n)*(atoms%lnonsph(n)+2)
                   ! ax = a*utu + b*utd ;  bx = a*dtu + b*dtd  (columns 0:lmp only)
                   CALL zgemm("N","N",lapw%nv(jintsp),lmp+1,lmp+1,one,a(1,0,jintsp),SIZE(a,1),&
                        utu(0,0),SIZE(utu,1),zero,ax,SIZE(ax,1))
                   CALL zgemm("N","N",lapw%nv(jintsp),lmp+1,lmp+1,one,b(1,0,jintsp),SIZE(a,1),&
                        utd(0,0),SIZE(utu,1),one,ax,SIZE(ax,1))

                   CALL zgemm("N","N",lapw%nv(jintsp),lmp+1,lmp+1,one,a(1,0,jintsp),SIZE(a,1),&
                        dtu(0,0),SIZE(utu,1),zero,bx,SIZE(ax,1))
                   CALL zgemm("N","N",lapw%nv(jintsp),lmp+1,lmp+1,one,b(1,0,jintsp),SIZE(a,1),&
                        dtd(0,0),SIZE(utu,1),one,bx,SIZE(ax,1))

                   !
                   !--->             update hamiltonian and overlap matrices
                   IF ( noco%l_noco .AND. (n_size>1) ) THEN
                      lapw%nv_tot = lapw%nv(1) + lapw%nv(2)
                   ELSE
                      lapw%nv_tot = lapw%nv(iintsp)
                   ENDIF
                   ! This rank's rows: kii = n_rank, n_rank+n_size, ... (0-based)
                   kii=n_rank
                   DO WHILE(kii<lapw%nv_tot)
                      ki = MOD(kii,lapw%nv(iintsp)) + 1
                      !Either use maximal blocksize or number of rows left to calculate
                      bsize=MIN(SIZE(aa_block,1),(lapw%nv(iintsp)-ki)/n_size+1)
                      IF (bsize<1) EXIT !nothing more to do here
                      ! bsize2: span of global rows covered by this block
                      bsize2=bsize*n_size
                      bsize2=MIN(bsize2,lapw%nv(iintsp)-ki+1)
                      ! aa_block(:bsize,:ki+bsize2-1) =
                      !    a(ki:ki+bsize2-1:n_size,0:lmp)*ax(:ki+bsize2-1,0:lmp)^H
                      !  + b(ki:ki+bsize2-1:n_size,0:lmp)*bx(:ki+bsize2-1,0:lmp)^H
                      IF (n_size==1) THEN !Make this a special case to avoid copy-in of a array
                         CALL zgemm("N","C",bsize,ki+bsize2-1,lmp+1,one,a(ki,0,iintsp),SIZE(a,1),&
                              ax(1,0),SIZE(ax,1),zero,aa_block,SIZE(aa_block,1))
                         CALL zgemm("N","C",bsize,ki+bsize2-1,lmp+1,one,b(ki,0,iintsp),SIZE(a,1),&
                              bx(1,0),SIZE(ax,1),one ,aa_block,SIZE(aa_block,1))
                      ELSE
                         ! The strided section ki:ki+bsize2-1:n_size holds exactly
                         ! bsize rows, which is the leading dimension of its
                         ! copied-in temporary.
                         CALL zgemm("N","C",bsize,ki+bsize2-1,lmp+1,one,&
                              a(ki:ki+bsize2-1:n_size,0:lmp,iintsp),bsize,&
                              ax(1,0),SIZE(ax,1),zero,aa_block,SIZE(aa_block,1))
                         CALL zgemm("N","C",bsize,ki+bsize2-1,lmp+1,one,&
                              b(ki:ki+bsize2-1:n_size,0:lmp,iintsp),bsize,&
                              bx(1,0),SIZE(ax,1),one,aa_block,SIZE(aa_block,1))
                      ENDIF
                      DO kb=1,bsize
                         IF ( noco%l_noco .AND. (.NOT. noco%l_ss) ) THEN
                            ! offset of row ki in this rank's packed storage
                            nc = 1+kii/n_size
                            ii = nc*(nc-1)/2*n_size-(nc-1)*(n_size-n_rank-1)
                            ! All MATMULs contract over lm = 0..lmp only: ax/bx are
                            ! defined only in those columns.
                            IF ( (n_size==1).OR.(kii+1<=lapw%nv(1)) ) THEN
                               aahlp(ii+1:ii+ki) = aahlp(ii+1:ii+ki)&
                                    +MATMUL(CONJG(ax(:ki,:lmp)),a(ki,:lmp,iintsp))&
                                    +MATMUL(CONJG(bx(:ki,:lmp)),b(ki,:lmp,iintsp))
                            ELSE                    ! components for <2||2> block unused
                               aa_tmphlp(:ki) = MATMUL(CONJG(ax(:ki,:lmp)),a(ki,:lmp,iintsp))&
                                    +MATMUL(CONJG(bx(:ki,:lmp)),b(ki,:lmp,iintsp))
                               !--->                   spin-down spin-down part
                               ij = ii + lapw%nv(1)
                               aa_c(ij+1:ij+ki)=aa_c(ij+1:ij+ki)+chi22*aa_tmphlp(:ki)
                               !--->                   spin-down spin-up part, lower triangle
                               ij =  ii
                               aa_c(ij+1:ij+ki)=aa_c(ij+1:ij+ki)+chi21*aa_tmphlp(:ki)
                            ENDIF
                            !-||
                         ELSEIF ( noco%l_noco .AND. noco%l_ss ) THEN
                            IF ( iintsp==1 .AND. jintsp==1 ) THEN
                               !--->                      spin-up spin-up part
                               kjmax = ki
                               chihlp = chi11
                               ii = (ki-1)*(ki)/2
                            ELSEIF ( iintsp==2 .AND. jintsp==2 ) THEN
                               !--->                      spin-down spin-down part
                               kjmax = ki
                               chihlp = chi22
                               ii = (lapw%nv(1)+atoms%nlotot+ki-1)*(lapw%nv(1)+atoms%nlotot+ki)/2+&
                                    lapw%nv(1)+atoms%nlotot
                            ELSE
                               !--->                      spin-down spin-up part
                               kjmax = lapw%nv(1)
                               chihlp = chi21
                               ii = (lapw%nv(1)+atoms%nlotot+ki-1)*(lapw%nv(1)+atoms%nlotot+ki)/2
                            ENDIF
                            aa_c(ii+1:ii+kjmax) = aa_c(ii+1:ii+kjmax) + chihlp*&
                                 (MATMUL(CONJG(ax(:kjmax,:lmp)),a(ki,:lmp,iintsp))&
                                 +MATMUL(CONJG(bx(:kjmax,:lmp)),b(ki,:lmp,iintsp)))
                         ELSE
                            ! collinear: add row kb of the block at the packed offset
                            nc = 1+kii/n_size
                            ii = nc*(nc-1)/2*n_size- (nc-1)*(n_size-n_rank-1)
                            IF (l_real) THEN
                               aa_r(ii+1:ii+ki) = aa_r(ii+1:ii+ki) + REAL(aa_block(kb,:ki))
                            ELSE
                               aa_c(ii+1:ii+ki) = aa_c(ii+1:ii+ki) + aa_block(kb,:ki)
                            ENDIF
                         ENDIF
                         ki=ki+n_size
                         kii=kii+n_size
                      ENDDO
                      !--->             end loop over ki
                   END DO
                   !--->       end loops over interstitial spin
                ENDDO
             ENDDO
          ENDIF              ! atoms%invsat(na) = 0 or 1
          !--->    end loop over equivalent atoms
       END DO
       IF ( noco%l_noco .AND. (.NOT. noco%l_ss) ) &
            CALL hsmt_hlptomat(atoms%nlotot,lapw%nv,sub_comm,chi11,chi21,chi22,aahlp,aa_c)
       !---> end loop over atom types (ntype)
    ENDDO ntyploop

    RETURN
  END SUBROUTINE hsmt_nonsph


END MODULE m_hsmt_nonsph