How to manually create a sparse matrix in Python
I have a text file that stores a sparse matrix in the following format:
0 234 345
0 236
0 345 365 465
0 12 35 379
The data is used for a classification task, and each row can be considered a feature vector. The first value on each line is the label; the remaining values are the indices of the features that are present.
I am trying to create a sparse matrix from these values (for use in a machine learning assignment with scikit-learn). I found and read the scipy.sparse documentation, but I don't understand how to incrementally grow a sparse matrix from raw data like this.
The examples I've found so far show how to take a dense matrix and transform it, or how to create your own sparse matrix with contrived data, but none of them helped me here. I found this related SO question (Plotting and updating a sparse matrix in python using scipy), but the example assumes you know the maximum COL, ROW sizes, which I don't have, so that data type is not appropriate.
So far, I have the following code to read the document and parse the values into something that seems reasonable:
def get_sparse_matrix():
    matrix = []
    with open("data.dat", 'r') as f:
        for i, line in enumerate(f):
            row = line.strip().split()
            label = row[0]
            features = row[1:]
            matrix.append([(i, col) for col in features])
    sparse_matrix = None  # magic happens here
    return sparse_matrix
So the questions:
- What is the appropriate sparse matrix type to use here?
- Am I heading in the right direction with the code I have?
Any help is greatly appreciated.
You can use coo_matrix():
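import numpy as np
from scipy import sparse

data = """0 234 345
0 236
0 345 365 465
0 12 35 379"""

# Collect the column indices of each row, skipping the leading label.
column_list = []
for line in data.split("\n"):
    values = [int(x) for x in line.strip().split()[1:]]
    column_list.append(values)

# Flatten the column indices and build the matching row index for each entry.
lengths = [len(row) for row in column_list]
cols = np.concatenate(column_list)
rows = np.repeat(np.arange(len(column_list)), lengths)

# Every listed feature gets the value 1; with no shape argument,
# the dimensions are inferred from the largest row and column index.
m = sparse.coo_matrix((np.ones_like(rows), (rows, cols)))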
Here's the code to check the result:
np.where(m.toarray())
output:
(array([0, 0, 1, 2, 2, 2, 3, 3, 3]),
array([234, 345, 236, 345, 365, 465, 12, 35, 379]))
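To connect this back to the question's file-reading function, here is a minimal sketch that applies the same COO approach to any iterable of lines and also keeps the labels. The names load_sparse and n_features are hypothetical, chosen for illustration, and the sketch assumes every line has the form "label col col ...". It converts the result to CSR, which is a reasonable choice when you need row slicing afterwards and which many scikit-learn estimators accept.

```python
import numpy as np
from scipy import sparse


def load_sparse(lines, n_features=None):
    """Parse 'label col col ...' lines into (labels, CSR matrix).

    Hypothetical helper: `lines` is any iterable of strings,
    for example an open file object.
    """
    labels, rows, cols = [], [], []
    row_count = 0
    for line in lines:
        fields = line.split()
        if not fields:  # skip blank lines
            continue
        labels.append(int(fields[0]))
        for col in fields[1:]:
            rows.append(row_count)
            cols.append(int(col))
        row_count += 1

    # If the number of columns is not known in advance, use the
    # largest column index seen in the data.
    if n_features is None:
        n_features = max(cols) + 1 if cols else 0

    data = np.ones(len(rows), dtype=np.float64)
    X = sparse.coo_matrix(
        (data, (np.array(rows, dtype=int), np.array(cols, dtype=int))),
        shape=(row_count, n_features),
    ).tocsr()
    return np.array(labels), X


sample = ["0 234 345", "0 236", "0 345 365 465", "0 12 35 379"]
y, X = load_sparse(sample)
print(X.shape, X.nnz)
```

To read from a file instead, pass the open file object, e.g. `with open("data.dat") as f: y, X = load_sparse(f)`. Passing n_features explicitly is useful when a test set has to match the column count of a training set. Note that converting COO to CSR sums duplicate entries, so a feature index listed twice on one line would end up with the value 2.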