描述
本文章试图将之前写的一些C++中的矩阵运算进行归纳整理,并且试图提供像Python的numpy库一样的矩阵运算功能。
并且为了降低使用成本,将这些功能写成模板头文件,在使用的时候直接导入该头文件即可。这也算是写C++热衷于造轮子的乐趣吧。
TODO 尝试加入CUDA
TODO 尝试使用纯指针实现模板类
代码
/**
******************************************************************************
矩阵计算的模板类
使用std::vector实现
* @file : Mat.h
* @author : Richardo Gu
* @date : 2024/6/18
******************************************************************************
*/
#ifndef Mat_H
#define Mat_H
#include <vector>
#include <iostream>
/**
* 矩阵转置
*/
template<typename T>
std::vector<std::vector<T>> transpose(std::vector<std::vector<T>> matrix) {
std::vector<std::vector<T>> array;
std::vector<T> tempArr;
for (int i = 0; i < matrix[0].size(); ++i) {
for (int j = 0; j < matrix.size(); ++j) {
tempArr.push_back(matrix[j][i]);
}
array.push_back(tempArr);
tempArr.erase(tempArr.begin(), tempArr.end());
}
return array;
}
/**
* 一维矩阵加法
*/
template<typename T>
std::vector<T> operator+(std::vector<T> const& arrA, std::vector<T> const& arrB) {
int length = arrA.size();
std::vector<T> result(length);
for (int i = 0; i < length; i++) {
result[i] = arrA[i] + arrB[i];
}
return result;
}
/**
* 一维矩阵+=运算
*/
template<typename T>
std::vector<T>& operator+=(std::vector<T>& arr_a, const std::vector<T>& arr_b) {
for (int i = 0; i < arr_a.size(); i++) {
arr_a[i] += arr_b[i];
}
return arr_a;
}
/**
* 二维矩阵加法
*/
template<typename T>
std::vector<std::vector<T>>
operator+(std::vector<std::vector<T>> const& arr_a, std::vector<std::vector<T>> const& arr_b) {
int length = arr_a.size();
std::vector<std::vector<T>> result(length);
for (int i = 0; i < length; i++) {
result[i] = arr_a[i] + arr_b[i];
}
return result;
}
/**
* 二维矩阵+=
*/
template<typename T>
std::vector<std::vector<T>>& operator+=(std::vector<std::vector<T>>& arr_a, std::vector<std::vector<T>> const& arr_b) {
int length = arr_a.size();
for (int i = 0; i < length; i++) {
arr_a[i] += arr_b[i];
}
return arr_a;
}
/**
* 一维矩阵减法
*/
template<typename T>
std::vector<T> operator-(std::vector<T> const& arr_a) {
int length = arr_a.size();
std::vector<T> result(length);
for (int i = 0; i < length; i++) {
result[i] = -arr_a[i];
}
return result;
}
/**
* 一维矩阵乘法
*/
template<typename T>
T operator*(std::vector<T> const& arr_a, std::vector<T> const& arr_b) {
T result = 0;
for (int i = 0; i < arr_a.size(); i++) {
result += arr_a[i] * arr_b[i];
}
return result;
}
/**
* 二维矩阵乘一维矩阵
*/
template<typename T>
std::vector<T> operator*(std::vector<std::vector<T>> const& arr_a, std::vector<T> const& arr_b) {
std::vector<T> result(arr_a.size());
for (int i = 0; i < arr_a.size(); i++) {
result[i] = arr_a[i] * arr_b;
}
return result;
}
/**
* 二维矩阵乘二维矩阵
*/
template<typename T>
std::vector<std::vector<T>>
operator*(std::vector<std::vector<T>> const& arr_a, std::vector<std::vector<T>> const& arr_b) {
int row = arr_a.size();
int col = arr_b[0].size();
std::vector<std::vector<T>> temp = transpose(arr_b);
std::vector<std::vector<T>> result = std::vector<std::vector<double>>(row, std::vector<double>(col));
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
result[i][j] = arr_a[i] * temp[j];
}
}
return result;
}
template<typename T>
std::ostream& operator<<(std::ostream& stream, const std::vector<T>& arr) {
stream << "[";
for (int i = 0; i < arr.size(); i++) {
stream << arr[i] << (i == arr.size() - 1 ? "" : ",");
}
return stream << "]" << std::endl;
}
/**
* 矩阵拼接 对v1进行修改
*/
template<typename T>
std::vector<T>& concatenate(std::vector<T>& v1, const std::vector<T> v2) {
v1.insert(v1.end(), v2.begin(), v2.end());
return v1;
}
template<typename T>
std::pair<bool, int> find(const std::vector<T>& arr, T data) {
int idx = 0;
for (const auto& i : arr) {
if (i == data) {
return {true, idx};
}
idx++;
}
return {false, 0}; // 没有找到则返回负数
}
/**
* 矩阵删除元素
*/
template<typename T>
std::vector<T> deleteObj(std::vector<T> arr, const std::vector<int>& obj) {
int newSize = arr.size() - obj.size();
if (newSize == 0) return arr;
if (newSize < 0) throw std::runtime_error("data to be delete oversize than origin array");
std::vector<T> newArr(newSize);
int deletedNum = 0;
for (int i = 0; i < arr.size(); i++) {
if (find(obj, i).first) {
deletedNum++;
continue;
}
newArr[i - deletedNum] = arr[i];
}
return newArr;
}
#endif //Mat_H