import numpy as np
from scipy import sparse, stats
from pprint import pprint

def test_matrix(m, k=4, t=1/16, symm=False, seed=None):
    """Create a random sparse m-by-m matrix with normally distributed entries.

    The matrix is returned as a dictionary mapping ``(i, j)`` index pairs to
    the values of the nonzero entries. Every diagonal entry is present and is
    drawn from the standard normal distribution. Each off-diagonal position is
    independently nonzero with a probability chosen so that a row holds ``k``
    off-diagonal entries on average (approximately ``k`` when ``symm`` is set,
    because mirrored entries can coincide). If ``k`` exceeds ``m - 1`` the
    off-diagonal part is simply dense.

    Keyword arguments:
    k: average number of off-diagonal entries per row
    t: standard deviation of the off-diagonal entries (normal, mean 0)
    symm: whether the matrix should be symmetric
    seed: pass in a fixed integer for replicable results

    Raises ValueError if m is negative.
    """
    if m < 0:
        raise ValueError(f"matrix size must be non-negative, got {m}")

    rng = np.random.default_rng(seed=seed)
    # .tolist() yields plain Python floats, matching the scalar draws used for
    # the off-diagonal values below.
    A_dict = {(i, i): aii
              for i, aii in enumerate(rng.normal(0, 1, size=m).tolist())}
    if m < 2:
        # No off-diagonal positions exist (and k/(m-1) would divide by zero).
        return A_dict

    # Each row has m-1 off-diagonal positions. In the symmetric case every
    # sampled entry is mirrored into the transposed position, so halve the
    # rate to keep roughly k entries per row. Clamp to 1 because k may exceed
    # m-1 for small matrices (binomial sampling rejects p > 1).
    p = k / (m - 1)
    if symm:
        p /= 2
    p = min(p, 1.0)

    for i in range(m):
        # Drawing a Binomial(m-1, p) count and then that many distinct columns
        # is equivalent to an independent Bernoulli(p) trial per position, but
        # avoids m Python-level RNG calls per row.
        count = rng.binomial(m - 1, p)
        cols = rng.choice(m - 1, size=count, replace=False)
        cols[cols >= i] += 1  # map 0..m-2 onto the columns other than i
        vals = rng.normal(0, t, size=count)
        for j, aij in zip(cols.tolist(), vals.tolist()):
            A_dict[i, j] = A_dict.get((i, j), 0) + aij
            if symm:
                A_dict[j, i] = A_dict.get((j, i), 0) + aij
    return A_dict


# Despite its name this is a fixture generator, not a test; keep pytest from
# collecting it when it is imported into a test module.
test_matrix.__test__ = False

def matvec_dict(A_dict, x):
    """Perform matrix-vector multiplication with a matrix stored as a dict.

    A_dict maps ``(i, j)`` index pairs to matrix entries; pairs that are absent
    are treated as zero. ``x`` may be any array-like. Entry ``i`` of the result
    is the sum over stored ``(i, j)`` of ``A_dict[i, j] * x[j]``. The result
    has the same shape as ``x`` (so a 2-D ``x`` is multiplied column-wise).
    Its dtype is at least float64 and is promoted to complex when either the
    entries or ``x`` are complex.
    """
    x = np.asarray(x)
    vals = np.array(list(A_dict.values()))
    # A fixed float64 accumulator would silently drop imaginary parts, so
    # promote over both operands (never below float64, like the original).
    y = np.zeros(x.shape, dtype=np.result_type(x, vals, np.float64))
    if vals.size:
        rows, cols = np.array(list(A_dict.keys())).T
        # Broadcast each value across the trailing axes of a multi-dim x.
        terms = vals.reshape((-1,) + (1,) * (x.ndim - 1)) * x[cols]
        # add.at accumulates repeated row indices in order; a fancy-indexed
        # `y[rows] += terms` would keep only one term per row.
        np.add.at(y, rows, terms)
    return y

if __name__ == '__main__':
    # Demo: build a small random sparse matrix, show its stored entries, then
    # multiply it by a random vector drawn from numpy's global RNG.
    n = 10
    A = test_matrix(n)
    pprint(A)
    v = np.random.normal(0, 1, n)
    print(v)
    Av = matvec_dict(A, v)
    print(Av)
