use std::error::Error;
use std::fmt;
use std::ops;

/// A dense, row-major matrix of `Copy` elements.
///
/// Every matrix carries a `default_val` that fills freshly created cells
/// (see [`Matrix::new`]). It is also the value produced by a matrix product
/// whose shared inner dimension is zero, since there are no terms to sum.
///
/// Equality (`PartialEq`) compares the dimensions, every element, *and*
/// `default_val`.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<T> {
    row_count: usize,
    col_count: usize,
    default_val: T,
    // Invariant: `rows.len() == row_count` and every row holds `col_count` elements.
    rows: Vec<Vec<T>>,
}

/// Errors returned by fallible [`Matrix`] operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixError {
    /// The given row index is not less than the matrix's row count.
    RowIndexOutOfBound(usize),
    /// The given column index is not less than the matrix's column count.
    ColIndexOutOfBound(usize),
    /// The operands' dimensions do not permit the requested operation.
    IncompatibleSize,
}

impl fmt::Display for MatrixError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MatrixError::RowIndexOutOfBound(row) => {
                write!(f, "row index {} is out of bounds", row)
            }
            MatrixError::ColIndexOutOfBound(col) => {
                write!(f, "column index {} is out of bounds", col)
            }
            MatrixError::IncompatibleSize => write!(f, "matrix dimensions are incompatible"),
        }
    }
}

impl Error for MatrixError {}

impl<T> Matrix<T>
where
    T: Copy,
{
    /// Creates a `row_count` x `col_count` matrix with every cell set to `default_val`.
    pub fn new(row_count: usize, col_count: usize, default_val: T) -> Self {
        let rows = vec![vec![default_val; col_count]; row_count];
        Matrix {
            row_count,
            col_count,
            default_val,
            rows,
        }
    }

    /// Creates a matrix with the same dimensions and `default_val` as `matrix`,
    /// with every cell set to that `default_val`.
    pub fn new_with_size(matrix: &Self) -> Self {
        Matrix::new(matrix.row_count, matrix.col_count, matrix.default_val)
    }

    /// Returns a reference to the element at (`row`, `col`), or `None` if
    /// either index is out of bounds.
    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
        // The struct invariant makes these checked lookups equivalent to
        // comparing against `row_count` / `col_count`.
        self.rows.get(row)?.get(col)
    }

    /// Overwrites the element at (`row`, `col`) with `val`.
    ///
    /// # Errors
    ///
    /// Returns [`MatrixError::RowIndexOutOfBound`] if `row` is out of bounds
    /// (checked first), otherwise [`MatrixError::ColIndexOutOfBound`] if `col`
    /// is out of bounds. The matrix is left unchanged on error.
    pub fn set(&mut self, val: T, row: usize, col: usize) -> Result<(), MatrixError> {
        if row >= self.row_count {
            return Err(MatrixError::RowIndexOutOfBound(row));
        }
        if col >= self.col_count {
            return Err(MatrixError::ColIndexOutOfBound(col));
        }
        self.rows[row][col] = val;
        Ok(())
    }

    /// Returns `true` if `self` and `other` have identical dimensions.
    fn size_equal(&self, other: &Self) -> bool {
        self.row_count == other.row_count && self.col_count == other.col_count
    }

    /// Builds a same-shaped matrix (same `default_val`) by applying `f` to each element.
    fn map_elements<F>(&self, f: F) -> Self
    where
        F: Fn(T) -> T,
    {
        let rows: Vec<Vec<T>> = self
            .rows
            .iter()
            .map(|row| row.iter().map(|&x| f(x)).collect())
            .collect();
        Matrix {
            row_count: self.row_count,
            col_count: self.col_count,
            default_val: self.default_val,
            rows,
        }
    }

    /// Combines two equally sized matrices element by element with `f`,
    /// keeping `self`'s `default_val`.
    ///
    /// Returns [`MatrixError::IncompatibleSize`] if the dimensions differ.
    fn zip_elements<F>(&self, other: &Self, f: F) -> Result<Self, MatrixError>
    where
        F: Fn(T, T) -> T,
    {
        if !self.size_equal(other) {
            return Err(MatrixError::IncompatibleSize);
        }
        let rows: Vec<Vec<T>> = self
            .rows
            .iter()
            .zip(&other.rows)
            .map(|(lhs, rhs)| lhs.iter().zip(rhs).map(|(&a, &b)| f(a, b)).collect())
            .collect();
        Ok(Matrix {
            row_count: self.row_count,
            col_count: self.col_count,
            default_val: self.default_val,
            rows,
        })
    }
}

impl<T> ops::Add<Matrix<T>> for Matrix<T>
where
    T: ops::Add<Output = T> + Copy,
{
    type Output = Result<Matrix<T>, MatrixError>;

    /// Element-wise sum.
    ///
    /// Fails with [`MatrixError::IncompatibleSize`] unless both operands have
    /// identical dimensions.
    fn add(self, rhs: Matrix<T>) -> Self::Output {
        self.zip_elements(&rhs, |a, b| a + b)
    }
}

impl<T> ops::Sub<Matrix<T>> for Matrix<T>
where
    T: ops::Sub<Output = T> + Copy,
{
    type Output = Result<Matrix<T>, MatrixError>;

    /// Element-wise difference (`self - rhs`).
    ///
    /// Fails with [`MatrixError::IncompatibleSize`] unless both operands have
    /// identical dimensions.
    fn sub(self, rhs: Matrix<T>) -> Self::Output {
        self.zip_elements(&rhs, |a, b| a - b)
    }
}

impl<T> ops::Mul<T> for Matrix<T>
where
    T: ops::Mul<Output = T> + Copy,
{
    type Output = Matrix<T>;

    /// Multiplies every element by the scalar `rhs`.
    fn mul(self, rhs: T) -> Self::Output {
        self.map_elements(|x| x * rhs)
    }
}

impl<T> ops::Mul<Matrix<T>> for Matrix<T>
where
    T: ops::Add<Output = T> + ops::Mul<Output = T> + Copy,
{
    type Output = Result<Matrix<T>, MatrixError>;

    /// Matrix product `self * rhs`, producing a `self.row_count` x `rhs.col_count`
    /// matrix that keeps `self`'s `default_val`.
    ///
    /// Fails with [`MatrixError::IncompatibleSize`] unless
    /// `self.col_count == rhs.row_count`.
    fn mul(self, rhs: Matrix<T>) -> Self::Output {
        if self.col_count != rhs.row_count {
            return Err(MatrixError::IncompatibleSize);
        }
        let rows: Vec<Vec<T>> = self
            .rows
            .iter()
            .map(|lhs_row| {
                (0..rhs.col_count)
                    .map(|col| {
                        // Sum the products directly instead of seeding the sum with
                        // `default_val`, which is a fill value and need not be zero.
                        // With an empty inner dimension there are no terms, so fall
                        // back to `default_val`.
                        lhs_row
                            .iter()
                            .zip(&rhs.rows)
                            .map(|(&a, rhs_row)| a * rhs_row[col])
                            .reduce(|acc, term| acc + term)
                            .unwrap_or(self.default_val)
                    })
                    .collect()
            })
            .collect();
        Ok(Matrix {
            row_count: self.row_count,
            col_count: rhs.col_count,
            default_val: self.default_val,
            rows,
        })
    }
}