3 min read

稀疏矩阵

稀疏矩阵中绝大多数元素为零,因此只存储其中极少数的非零元素。由于不再按位置顺序存放全部元素,每个非零元素除了值之外还必须保存它的下标:可以用一个三元组 <row, column, value> 唯一表示一个非零元素。下面的算法都假设这些三元组按行优先顺序(先按行号、再按列号从小到大)存放。

# 类定义

// <iostream.h> is a pre-standard header that current compilers no longer
// provide; <iostream> and <cstdlib> are the standard spellings.
#include <iostream>
#include <cstdlib>
// The code below uses these standard names unqualified (as it did with the
// old <iostream.h>), so bring exactly those names into scope.
using std::cerr;
using std::cout;
using std::endl;
using std::istream;
using std::ostream;
using std::exit;

// Default capacity (number of triples the storage array can hold) used by
// SparseMatrix when no explicit size is given.
const int DefaultSize = 100;
// One non-zero element of a sparse matrix, stored as a <row, col, value> triple.
template <class T>
struct Trituple {
	int row, col;   // position of the element in the matrix
	T value;        // the element's value

	// Copies all three fields from x and returns *this so that assignments can
	// be chained. The parameter is const so const triples can be assigned from;
	// returning *this is required (falling off the end of a function that
	// returns a reference is undefined behavior).
	Trituple<T>& operator = (const Trituple<T>& x) {
		row = x.row;
		col = x.col;
		value = x.value;
		return *this;
	}
};

// A sparse matrix that stores only its non-zero elements, as an array of
// <row, col, value> triples.
template <class T>
class SparseMatrix {
	// Declared as templates: a plain (non-template) friend declaration inside a
	// class template names a separate non-template function for every T, which
	// a template definition of operator<< / operator>> does not provide, so the
	// calls would fail to link (and GCC warns with -Wnon-template-friend).
	template <class U>
	friend ostream& operator << (ostream& out, SparseMatrix<U>& M);
	template <class U>
	friend istream& operator >> (istream& in, SparseMatrix<U>& M);

public:
	// Creates an empty matrix whose triple array can hold maxSz elements.
	SparseMatrix(int maxSz = DefaultSize);
	// NOTE(review): the copy constructor and operator= take non-const
	// references, so they cannot copy from const objects or temporaries.
	// Switching to const SparseMatrix<T>& requires changing their out-of-class
	// definitions at the same time.
	SparseMatrix(SparseMatrix<T>& x);
	~SparseMatrix() {
		delete []smArray;
	}
	SparseMatrix<T>& operator = (SparseMatrix<T>& x);
	// Transpose by scanning all triples once per column.
	SparseMatrix<T> Transpose();
	// Transpose using per-column counts and a single placement pass.
	SparseMatrix<T> FastTranspose();
	SparseMatrix<T> Add(SparseMatrix<T>& b);
	SparseMatrix<T> Multiply(SparseMatrix<T>& b);

private:
	int Rows, Cols, Terms;   // dimensions, and number of triples in use
	Trituple<T> *smArray;    // triple storage (heap array of maxTerms entries)
	int maxTerms;            // capacity of smArray
};

// Creates an empty (0 x 0, no triples) sparse matrix whose triple array can
// hold maxSz elements.
// Prints an error and terminates the program with exit(1) if maxSz < 1 or if
// the triple array cannot be allocated.
template <class T>
SparseMatrix<T>::SparseMatrix(int maxSz)
	: Rows(0), Cols(0), Terms(0), smArray(NULL), maxTerms(maxSz) {
	if (maxSz < 1) {
		cerr << "矩阵初始化出错!" << endl;
		exit(1);
	}
	// Plain new[] reports failure by throwing (std::bad_alloc), never by
	// returning NULL, so a NULL check after it can never fire. Catch the
	// failure instead so the error path really runs.
	try {
		smArray = new Trituple<T>[maxSz];
	} catch (...) {
		cerr << "存储出错!" << endl;
		exit(1);
	}
}

// Copy constructor: deep-copies x's dimensions, capacity and its Terms triples
// into a newly allocated array of the same capacity.
// Prints an error and terminates the program with exit(1) if the array cannot
// be allocated.
// The parameter stays a non-const reference to match the class declaration.
template <class T>
SparseMatrix<T>::SparseMatrix(SparseMatrix<T>& x)
	: Rows(x.Rows), Cols(x.Cols), Terms(x.Terms), smArray(NULL),
	  maxTerms(x.maxTerms) {
	// maxTerms is initialized above, before it is used as the array size.
	// Plain new[] throws on failure rather than returning NULL.
	try {
		smArray = new Trituple<T>[maxTerms];
	} catch (...) {
		cerr << "存储分配出错!" << endl;
		exit(1);
	}
	for (int i = 0; i < Terms; i ++) {
		smArray[i] = x.smArray[i];
	}
}

# 转置

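最简单的方法是把每个三元组的行、列下标互换。但互换之后三元组不再按行优先顺序排列,需要重新排序。下面的实现避免了单独的排序步骤:它按列号 k = 0, 1, …, Cols-1 依次扫描全部三元组,把第 k 列的元素依次放入结果,这样得到的结果自然按行有序。时间复杂度为 O(Cols × Terms)。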

// Returns the transpose of this matrix.
// For each column c of *this, in increasing order, every stored triple is
// scanned and the ones in column c are appended to the result with row and
// column swapped. Cost: Cols passes over all Terms triples.
template <class T>
SparseMatrix<T> SparseMatrix<T>::Transpose() {
	SparseMatrix<T> trans(maxTerms);
	trans.Rows = Cols;
	trans.Cols = Rows;
	trans.Terms = Terms;

	int next = 0;   // next free slot in trans.smArray
	for (int c = 0; Terms > 0 && c < Cols; ++c) {
		for (int t = 0; t < Terms; ++t) {
			const Trituple<T>& src = smArray[t];
			if (src.col != c) {
				continue;
			}
			Trituple<T>& dst = trans.smArray[next];
			dst.row = c;
			dst.col = src.row;
			dst.value = src.value;
			++next;
		}
	}
	return trans;
}

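快速转置算法:先统计原矩阵每一列的非零元素个数(即转置矩阵每一行的元素个数),由此算出转置矩阵每一行的第一个元素在三元组数组中的起始位置;然后只需顺序扫描一遍原三元组,就能把每个元素直接放到正确的位置。时间复杂度为 O(Cols + Terms),代价是需要两个长度为 Cols 的辅助数组。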

// Returns the transpose of this matrix using one counting pass and one
// placement pass over the triples.
// count[c] is the number of triples in column c of *this, i.e. in row c of
// the result. slot[c] starts as the index of the first result triple of row c
// (a running prefix sum of count) and is advanced as that row is filled.
template <class T>
SparseMatrix<T> SparseMatrix<T>::FastTranspose() {
	int *count = new int[Cols];
	int *slot = new int[Cols];
	SparseMatrix<T> result(maxTerms);
	result.Rows = Cols;
	result.Cols = Rows;
	result.Terms = Terms;

	if (Terms > 0) {
		for (int c = 0; c < Cols; ++c) {
			count[c] = 0;
		}
		for (int t = 0; t < Terms; ++t) {
			++count[smArray[t].col];
		}
		int running = 0;
		for (int c = 0; c < Cols; ++c) {
			slot[c] = running;
			running += count[c];
		}
		for (int t = 0; t < Terms; ++t) {
			const Trituple<T>& src = smArray[t];
			Trituple<T>& dst = result.smArray[slot[src.col]];
			dst.row = src.col;
			dst.col = src.row;
			dst.value = src.value;
			++slot[src.col];
		}
	}

	delete[] count;
	delete[] slot;
	return result;
}

# 加法

// Returns the element-wise sum (*this) + b as a new sparse matrix.
//
// Both operands are assumed to store their triples in row-major order
// (sorted by row, then by column). The two triple lists are merged in that
// order, so the result is row-major as well. Entries whose sum is zero are
// not stored.
//
// If the dimensions differ, "Incompatible matrices" is printed to cerr and an
// empty (0 x 0, no triples) matrix is returned.
template <class T>
SparseMatrix<T> SparseMatrix<T>::Add(SparseMatrix<T>& b) {
	// The merge emits at most Terms + b.Terms triples. The extra slot keeps the
	// capacity >= 1, because the constructor rejects smaller sizes.
	SparseMatrix<T> result(Terms + b.Terms + 1);
	if (Rows != b.Rows || Cols != b.Cols) {
		cerr << "Incompatible matrices" << endl;
		return result;
	}
	result.Rows = Rows;
	result.Cols = Cols;
	result.Terms = 0;

	int i = 0, j = 0;
	while (i < Terms && j < b.Terms) {
		Trituple<T>& x = smArray[i];
		Trituple<T>& y = b.smArray[j];
		// Compare positions in row-major order as (row, col) pairs. Unlike
		// the linear index Cols * row + col, this cannot overflow.
		if (x.row < y.row || (x.row == y.row && x.col < y.col)) {
			result.smArray[result.Terms ++] = x;
			i ++;
		} else if (x.row > y.row || (x.row == y.row && x.col > y.col)) {
			result.smArray[result.Terms ++] = y;
			j ++;
		} else {
			// Same position: store the sum, but drop it when the two values
			// cancel so that only non-zero elements are kept.
			T sum = x.value + y.value;
			if (sum != 0) {
				Trituple<T>& out = result.smArray[result.Terms ++];
				out.row = x.row;
				out.col = x.col;
				out.value = sum;
			}
			i ++;
			j ++;
		}
	}
	// Copy whatever remains of the operand that was not exhausted.
	for (; i < Terms; i ++) {
		result.smArray[result.Terms ++] = smArray[i];
	}
	for (; j < b.Terms; j ++) {
		result.smArray[result.Terms ++] = b.smArray[j];
	}
	return result;
}

# 乘法

// Returns the matrix product (*this) * b as a new sparse matrix.
//
// Both operands are assumed to store their triples in row-major order (sorted
// by row, then by column). For each row of *this, the products with the
// matching rows of b are accumulated in a dense buffer of b.Cols values. The
// non-zero sums are then appended to the result in column order, so the
// result is row-major too.
//
// If the inner dimensions differ (Cols != b.Rows), "Incompatible matrices" is
// printed to cerr and an empty (0 x 0, no triples) matrix is returned.
template <class T>
SparseMatrix<T> SparseMatrix<T>::Multiply(SparseMatrix<T>& b) {
	if (Cols != b.Rows) {
		SparseMatrix<T> empty;
		cerr << "Incompatible matrices" << endl;
		return empty;
	}

	// rowSize[r]  : number of triples in row r of b.
	// rowStart[r] : index in b.smArray of the first triple of row r. Since
	//               rowStart[b.Rows] == b.Terms, row r of b occupies
	//               [rowStart[r], rowStart[r + 1]), so no sentinel slot is
	//               needed in either operand.
	int *rowSize = new int[b.Rows];
	int *rowStart = new int[b.Rows + 1];
	int i;
	for (i = 0; i < b.Rows; i ++) {
		rowSize[i] = 0;
	}
	for (i = 0; i < b.Terms; i ++) {
		rowSize[b.smArray[i].row] ++;   // rows of b, not of *this
	}
	rowStart[0] = 0;
	for (i = 1; i <= b.Rows; i ++) {
		rowStart[i] = rowStart[i - 1] + rowSize[i - 1];
	}

	// Upper bound on the number of result triples. For one row of *this, the
	// result row can only have non-zeros in columns touched by the matching
	// rows of b, and it has at most b.Cols columns. Capping while summing also
	// keeps the per-row count from growing large.
	int capacity = 0;
	for (i = 0; i < Terms; ) {
		int row = smArray[i].row;
		int candidates = 0;
		for (; i < Terms && smArray[i].row == row; i ++) {
			candidates += rowSize[smArray[i].col];
			if (candidates > b.Cols) {
				candidates = b.Cols;
			}
		}
		capacity += candidates;
	}
	if (capacity < 1) {
		capacity = 1;   // the constructor rejects capacities below 1
	}

	SparseMatrix<T> result(capacity);
	result.Rows = Rows;   // set even when *this has no triples
	result.Cols = b.Cols;
	result.Terms = 0;

	T *temp = new T[b.Cols];   // dense accumulator for one result row
	int Current = 0;
	while (Current < Terms) {
		int RowA = smArray[Current].row;
		for (i = 0; i < b.Cols; i ++) {
			temp[i] = 0;
		}
		// Add a(RowA, ColA) * b(ColA, c) for every stored a(RowA, ColA).
		while (Current < Terms && smArray[Current].row == RowA) {
			int ColA = smArray[Current].col;
			for (int k = rowStart[ColA]; k < rowStart[ColA + 1]; k ++) {
				temp[b.smArray[k].col] += smArray[Current].value * b.smArray[k].value;
			}
			Current ++;
		}
		for (i = 0; i < b.Cols; i ++) {
			if (temp[i] != 0) {
				Trituple<T>& out = result.smArray[result.Terms ++];
				out.row = RowA;
				out.col = i;
				out.value = temp[i];
			}
		}
	}

	delete[] rowSize;
	delete[] rowStart;
	delete[] temp;
	return result;
}