/*
    -- MAGMA (version 2.0) --
       Univ. of Tennessee, Knoxville
       Univ. of California, Berkeley
       Univ. of Colorado, Denver
       @date

       @author Hartwig Anzt

       @generated from sparse/src/zpqmr_merge.cpp, normal z -> s, Fri Sep 19 13:51:38 2025
*/

#include "magmasparse_internal.h"

// Relative tolerance constant: LAPACK slamch("E"), which the LAPACK
// documentation describes as the relative machine precision (eps).
// Not referenced in the visible part of this file.
#define RTOLERANCE     lapackf77_slamch( "E" )
// Absolute tolerance constant (same slamch("E") value); the solver in this
// file uses it as a lower bound for its initial-residual threshold r0.
#define ATOLERANCE     lapackf77_slamch( "E" )


/**
    Purpose
    -------

    Solves a system of linear equations
       A * X = B
    where A is a general real matrix.
    This is a GPU implementation of the preconditioned
    Quasi-Minimal Residual method (QMR) using custom-designed kernels.

    The routine writes solver feedback into solver_par (solver, numiter,
    spmv_count, init_res, iter_res, final_res, runtime, info) on the paths
    that reach the respective assignments, and, if solver_par->verbose > 0,
    records residual and timing histories in res_vec / timing.

    Arguments
    ---------

    @param[in]
    A           magma_s_matrix
                input matrix A

    @param[in]
    b           magma_s_matrix
                RHS b

    @param[in,out]
    x           magma_s_matrix*
                solution approximation

    @param[in,out]
    solver_par  magma_s_solver_par*
                solver parameters

    @param[in]
    precond_par magma_s_preconditioner*
                preconditioner

    @param[in]
    queue       magma_queue_t
                Queue to execute in.

    @return     magma_int_t status, also stored in solver_par->info:
                - MAGMA_SUCCESS if the initial residual is already below
                  the stopping threshold, or if the recomputed final
                  residual is below the initial residual and the last
                  iteration residual is below rtol*||b|| or atol;
                - MAGMA_SLOW_CONVERGENCE if the final residual is below
                  the initial residual but neither tolerance is met;
                - MAGMA_DIVERGENCE otherwise;
                - presumably the error code of a failing call wrapped in
                  CHECK (CHECK is defined outside this file).

    @ingroup magmasparse_sgesv
    ********************************************************************/

extern "C" magma_int_t
magma_spqmr_merge(
    magma_s_matrix A, magma_s_matrix b, magma_s_matrix *x,
    magma_s_solver_par *solver_par,
    magma_s_preconditioner *precond_par,
    magma_queue_t queue )
{
    magma_int_t info = MAGMA_NOTCONVERGED;
    
    // prepare solver feedback
    solver_par->solver = Magma_QMR;
    solver_par->numiter = 0;
    solver_par->spmv_count = 0;
    
    
    // local variables
    float c_zero = MAGMA_S_ZERO, c_one = MAGMA_S_ONE;
    // solver variables
    float nom0, r0, res=0.0, nomb;
    float rho = c_one, rho1 = c_one, eta = -c_one , pds = c_one, 
                        thet = c_one, thet1 = c_one, epsilon = c_one, 
                        beta = c_one, delta = c_one, pde = c_one, rde = c_one,
                        gamm = c_one, gamm1 = c_one, psi = c_one;
    
    // total number of vector entries (rows of A times number of RHS columns)
    magma_int_t dofs = A.num_rows* b.num_cols;

    // need to transpose the matrix
    magma_s_matrix AT={Magma_CSR};
    
    // GPU workspace
    // (the formerly allocated r_tld vector was never read and has been dropped)
    magma_s_matrix r={Magma_CSR},
                    v={Magma_CSR}, w={Magma_CSR}, wt={Magma_CSR},
                    d={Magma_CSR}, s={Magma_CSR}, z={Magma_CSR}, q={Magma_CSR}, 
                    p={Magma_CSR}, pt={Magma_CSR}, y={Magma_CSR},
                    vt={Magma_CSR}, yt={Magma_CSR}, zt={Magma_CSR};
    CHECK( magma_svinit( &r, Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));
    CHECK( magma_svinit( &v, Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));
    CHECK( magma_svinit( &w, Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));
    CHECK( magma_svinit( &wt,Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));
    CHECK( magma_svinit( &d, Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));
    CHECK( magma_svinit( &s, Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));
    CHECK( magma_svinit( &z, Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));
    CHECK( magma_svinit( &q, Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));
    CHECK( magma_svinit( &p, Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));
    CHECK( magma_svinit( &pt,Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));
    CHECK( magma_svinit( &y, Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));
    CHECK( magma_svinit( &yt, Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));
    CHECK( magma_svinit( &vt, Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));
    CHECK( magma_svinit( &zt, Magma_DEV, A.num_rows, b.num_cols, c_zero, queue ));

    
    // solver setup: initial residual vector r and its norm nom0
    CHECK(  magma_sresidualvec( A, b, *x, &r, &nom0, queue));
    solver_par->init_res = nom0;
    // vt = r, wt = r
    magma_scopy( dofs, r.dval, 1, vt.dval, 1, queue );  
    magma_scopy( dofs, r.dval, 1, wt.dval, 1, queue );   
     
    
    // transpose the matrix; a failure here must not be ignored, because
    // AT is used for the transposed SpMV in every iteration
    CHECK( magma_smtransposeconjugate( A, &AT, queue ));
    
    // norm of the right-hand side; a zero norm is replaced by 1 so that
    // it can be used as a scaling factor / divisor below
    nomb = magma_snrm2( dofs, b.dval, 1, queue );
    if ( nomb == 0.0 ){
        nomb=1.0;
    }       
    // threshold for the initial residual, bounded below by ATOLERANCE
    if ( (r0 = nomb * solver_par->rtol) < ATOLERANCE ){
        r0 = ATOLERANCE;
    }
    solver_par->final_res = solver_par->init_res;
    solver_par->iter_res = solver_par->init_res;
    if ( solver_par->verbose > 0 ) {
        solver_par->res_vec[0] = (real_Double_t)nom0;
        solver_par->timing[0] = 0.0;
    }
    // initial guess already good enough: nothing to iterate
    if ( nom0 < r0 ) {
        info = MAGMA_SUCCESS;
        goto cleanup;
    }
        // no precond: y = vt, z = wt
        // magma_scopy( dofs, vt.dval, 1, y.dval, 1, queue );
        // magma_scopy( dofs, wt.dval, 1, z.dval, 1, queue );
    // with precond: y from vt (left, non-transposed), z from wt (right, transposed)
    CHECK( magma_s_applyprecond_left( MagmaNoTrans, A, vt, &y, precond_par, queue ));
    CHECK( magma_s_applyprecond_right( MagmaTrans, A, wt, &z, precond_par, queue ));

    // psi = norm(z), rho = norm(y)
    psi = magma_ssqrt( magma_sdot( dofs, z.dval, 1, z.dval, 1, queue ));
    rho = magma_ssqrt( magma_sdot( dofs, y.dval, 1, y.dval, 1, queue ));
        // v = vt / rho
        // y = y / rho
        // w = wt / psi
        // z = z / psi
    magma_sqmr_8(  
    r.num_rows, 
    r.num_cols, 
    rho,
    psi,
    vt.dval,
    wt.dval,
    y.dval, 
    z.dval,
    v.dval,
    w.dval,
    queue );

    //Chronometry
    real_Double_t tempo1, tempo2;
    tempo1 = magma_sync_wtime( queue );
    
    solver_par->numiter = 0;
    // start iteration
    do
    {
        solver_par->numiter++;
        // rho and psi are used as divisors in the scaling kernel
        if( magma_s_isnan_inf( rho ) || magma_s_isnan_inf( psi ) ){
            info = MAGMA_DIVERGENCE;
            break;
        }
            // delta = z' * y;
        delta = magma_sdot( dofs, z.dval, 1, y.dval, 1, queue );
        if( magma_s_isnan_inf( delta ) ){
            info = MAGMA_DIVERGENCE;
            break;
        }
            // no precond: yt = y, zt = z
        // magma_scopy( dofs, y.dval, 1, yt.dval, 1, queue );
        // magma_scopy( dofs, z.dval, 1, zt.dval, 1, queue );
        // with precond: yt from y (right, non-transposed), zt from z (left, transposed)
        CHECK( magma_s_applyprecond_right( MagmaNoTrans, A, y, &yt, precond_par, queue ));
        CHECK( magma_s_applyprecond_left( MagmaTrans, A, z, &zt, precond_par, queue ));

        
        if( solver_par->numiter == 1 ){
                // p = y;
                // q = z;
            magma_scopy( dofs, yt.dval, 1, p.dval, 1, queue );
            magma_scopy( dofs, zt.dval, 1, q.dval, 1, queue );
        }
        else{
            pde = psi * delta / epsilon;
            rde = rho * MAGMA_S_CONJ(delta/epsilon);
                // p = yt - pde * p
                // q = zt - rde * q
            magma_sqmr_2(  
            r.num_rows, 
            r.num_cols, 
            pde,
            rde,
            yt.dval,
            zt.dval,
            p.dval, 
            q.dval, 
            queue );
        }
        if( magma_s_isnan_inf( rho ) || magma_s_isnan_inf( psi ) ){
            info = MAGMA_DIVERGENCE;
            break;
        }

        // pt = A * p
        CHECK( magma_s_spmv( c_one, A, p, c_zero, pt, queue ));
        solver_par->spmv_count++;
            // epsilon = q' * pt;
        epsilon = magma_sdot( dofs, q.dval, 1, pt.dval, 1, queue );
        beta = epsilon / delta;

        if( magma_s_isnan_inf( epsilon ) || magma_s_isnan_inf( beta ) ){
            info = MAGMA_DIVERGENCE;
            break;
        }
            // vt = pt - beta * v;
        magma_sqmr_7(  
        r.num_rows, 
        r.num_cols, 
        beta,
        pt.dval,
        v.dval,
        vt.dval,
        queue );
        
            // wt = A' * q - beta' * w;
        CHECK( magma_s_spmv( c_one, AT, q, c_zero, wt, queue ));
        solver_par->spmv_count++;
        magma_saxpy( dofs, - MAGMA_S_CONJ( beta ), w.dval, 1, wt.dval, 1, queue );  
            // no precond: z = wt
        // magma_scopy( dofs, wt.dval, 1, z.dval, 1, queue );
        CHECK( magma_s_applyprecond_right( MagmaTrans, A, wt, &z, precond_par, queue ));
            // no precond: y = vt
        // magma_scopy( dofs, vt.dval, 1, y.dval, 1, queue );
        CHECK( magma_s_applyprecond_left( MagmaNoTrans, A, vt, &y, precond_par, queue ));

        // keep the previous rho for the eta update
        rho1 = rho;      
            // rho = norm(y);
        rho = magma_ssqrt( magma_sdot( dofs, y.dval, 1, y.dval, 1, queue ));
        
        // keep previous thet / gamm for the pds and eta updates
        thet1 = thet;        
        thet = rho / (gamm * MAGMA_S_MAKE( MAGMA_S_ABS(beta), 0.0 ));
        gamm1 = gamm;        
        
        gamm = c_one / magma_ssqrt(c_one + thet*thet);        
        eta = - eta * rho1 * gamm * gamm / (beta * gamm1 * gamm1);        

        if( magma_s_isnan_inf( thet ) || magma_s_isnan_inf( gamm ) || magma_s_isnan_inf( eta ) ){
            info = MAGMA_DIVERGENCE;
            break;
        }

        if( solver_par->numiter == 1 ){
                // d = eta * p + pds * d;
                // s = eta * pt + pds * d;
                // x = x + d;
                // r = r - s;
            // NOTE(review): pds is not passed to this kernel; d and s are
            // zero-initialized above, so presumably the pds terms drop out
            // in the first iteration - confirm against magma_sqmr_4.
            magma_sqmr_4(  
            r.num_rows, 
            r.num_cols, 
            eta,
            p.dval,
            pt.dval,
            d.dval, 
            s.dval, 
            x->dval, 
            r.dval, 
            queue );
        }
        else{
                // pds = (thet1 * gamm)^2;
            pds = (thet1 * gamm) * (thet1 * gamm);
                // d = eta * p + pds * d;
                // s = eta * pt + pds * d;
                // x = x + d;
                // r = r - s;
            magma_sqmr_5(  
            r.num_rows, 
            r.num_cols, 
            eta,
            pds,
            p.dval,
            pt.dval,
            d.dval, 
            s.dval, 
            x->dval, 
            r.dval, 
            queue );
        }
            // psi = norm(z);
        psi = magma_ssqrt( magma_sdot( dofs, z.dval, 1, z.dval, 1, queue ) );
        
        // norm of the recursively updated residual vector r
        res = magma_snrm2( dofs, r.dval, 1, queue );
        
        // record residual / elapsed time every `verbose` iterations
        if ( solver_par->verbose > 0 ) {
            tempo2 = magma_sync_wtime( queue );
            if ( (solver_par->numiter)%solver_par->verbose == c_zero ) {
                solver_par->res_vec[(solver_par->numiter)/solver_par->verbose]
                        = (real_Double_t) res;
                solver_par->timing[(solver_par->numiter)/solver_par->verbose]
                        = (real_Double_t) tempo2-tempo1;
            }
        }
            // v = vt / rho
            // y = y / rho
            // w = wt / psi
            // z = z / psi
        magma_sqmr_8(  
        r.num_rows, 
        r.num_cols, 
        rho,
        psi,
        vt.dval,
        wt.dval,
        y.dval, 
        z.dval,
        v.dval,
        w.dval,
        queue );

        // stop on relative (w.r.t. ||b||) or absolute residual tolerance
        if ( res/nomb <= solver_par->rtol || res <= solver_par->atol ){
            break;
        }
    }
    while ( solver_par->numiter+1 <= solver_par->maxiter );
    
    tempo2 = magma_sync_wtime( queue );
    solver_par->runtime = (real_Double_t) tempo2-tempo1;
    // recompute the residual from A, b and x for the final report
    float residual;
    CHECK(  magma_sresidualvec( A, b, *x, &r, &residual, queue));
    solver_par->iter_res = res;
    solver_par->final_res = residual;

    // Classify the outcome. Every branch below assigns info, so a
    // MAGMA_DIVERGENCE set by a break inside the loop is re-evaluated here.
    // NOTE(review): in the code visible here info is never MAGMA_SUCCESS at
    // this point (it starts as MAGMA_NOTCONVERGED and the loop only sets
    // MAGMA_DIVERGENCE), so the first branch looks unreachable - confirm
    // against the CHECK macro, which is defined outside this file.
    if ( solver_par->numiter < solver_par->maxiter && info == MAGMA_SUCCESS ) {
        info = MAGMA_SUCCESS;
    } else if ( solver_par->init_res > solver_par->final_res ) {
        if ( solver_par->verbose > 0 ) {
            if ( (solver_par->numiter)%solver_par->verbose == c_zero ) {
                solver_par->res_vec[(solver_par->numiter)/solver_par->verbose]
                        = (real_Double_t) res;
                solver_par->timing[(solver_par->numiter)/solver_par->verbose]
                        = (real_Double_t) tempo2-tempo1;
            }
        }
        info = MAGMA_SLOW_CONVERGENCE;
        if( solver_par->iter_res < solver_par->rtol*nomb ||
            solver_par->iter_res < solver_par->atol ) {
            info = MAGMA_SUCCESS;
        }
    }
    else {
        if ( solver_par->verbose > 0 ) {
            if ( (solver_par->numiter)%solver_par->verbose == c_zero ) {
                solver_par->res_vec[(solver_par->numiter)/solver_par->verbose]
                        = (real_Double_t) res;
                solver_par->timing[(solver_par->numiter)/solver_par->verbose]
                        = (real_Double_t) tempo2-tempo1;
            }
        }
        info = MAGMA_DIVERGENCE;
    }
    
cleanup:
    // release all workspace vectors and the transposed matrix
    magma_smfree(&r, queue );
    magma_smfree(&v,  queue );
    magma_smfree(&w,  queue );
    magma_smfree(&wt, queue );
    magma_smfree(&d,  queue );
    magma_smfree(&s,  queue );
    magma_smfree(&z,  queue );
    magma_smfree(&q,  queue );
    magma_smfree(&p,  queue );
    magma_smfree(&zt, queue );
    magma_smfree(&vt, queue );
    magma_smfree(&yt, queue );
    magma_smfree(&pt, queue );
    magma_smfree(&y,  queue );
    magma_smfree(&AT, queue );
    
    solver_par->info = info;
    return info;
}   /* magma_spqmr_merge */
