使用java写的矩阵乘法实例(Strassen算法)
Strassen算法于1969年由德国数学家Strassen提出,该方法引入七个中间变量,每个中间变量都只需要进行一次乘法运算。而朴素算法却需要进行8次乘法运算。
原理
Strassen算法的原理如下所示,使用sympy验证Strassen算法的正确性
importsympyass A=s.Symbol("A") B=s.Symbol("B") C=s.Symbol("C") D=s.Symbol("D") E=s.Symbol("E") F=s.Symbol("F") G=s.Symbol("G") H=s.Symbol("H") p1=A*(F-H) p2=(A+B)*H p3=(C+D)*E p4=D*(G-E) p5=(A+D)*(E+H) p6=(B-D)*(G+H) p7=(A-C)*(E+F) print(A*E+B*G,(p5+p4-p2+p6).simplify()) print(A*F+B*H,(p1+p2).simplify()) print(C*E+D*G,(p3+p4).simplify()) print(C*F+D*H,(p1+p5-p3-p7).simplify())
复杂度分析
$$f(N)=7\timesf(\frac{N}{2})=7^2\timesf(\frac{N}{4})=...=7^k\timesf(\frac{N}{2^k})$$
最终复杂度为$7^{log_2N}=N^{log_27}$
java矩阵乘法(Strassen算法)
代码如下,可以看看数据结构的定义,时间换空间。
publicclassMatrix{ privatefinalMatrix[]_matrixArray; privatefinalintn; privateintelement; publicMatrix(intn){ this.n=n; if(n!=1){ this._matrixArray=newMatrix[4]; for(inti=0;i<4;i++){ this._matrixArray[i]=newMatrix(n/2); } }else{ this._matrixArray=null; } } privateMatrix(intn,booleanneedInit){ this.n=n; if(n!=1){ this._matrixArray=newMatrix[4]; }else{ this._matrixArray=null; } } publicvoidset(inti,intj,inta){ if(n==1){ element=a; }else{ intsize=n/2; this._matrixArray[(i/size)*2+(j/size)].set(i%size,j%size,a); } } publicMatrixmulti(Matrixm){ Matrixresult=null; if(n==1){ result=newMatrix(1); result.set(0,0,(element*m.element)); }else{ result=newMatrix(n,false); result._matrixArray[0]=P5(m).add(P4(m)).minus(P2(m)).add(P6(m)); result._matrixArray[1]=P1(m).add(P2(m)); result._matrixArray[2]=P3(m).add(P4(m)); result._matrixArray[3]=P5(m).add(P1(m)).minus(P3(m)).minus(P7(m)); } returnresult; } publicMatrixadd(Matrixm){ Matrixresult=null; if(n==1){ result=newMatrix(1); result.set(0,0,(element+m.element)); }else{ result=newMatrix(n,false); result._matrixArray[0]=this._matrixArray[0].add(m._matrixArray[0]); result._matrixArray[1]=this._matrixArray[1].add(m._matrixArray[1]); result._matrixArray[2]=this._matrixArray[2].add(m._matrixArray[2]); result._matrixArray[3]=this._matrixArray[3].add(m._matrixArray[3]);; } returnresult; } publicMatrixminus(Matrixm){ Matrixresult=null; if(n==1){ result=newMatrix(1); result.set(0,0,(element-m.element)); }else{ result=newMatrix(n,false); result._matrixArray[0]=this._matrixArray[0].minus(m._matrixArray[0]); result._matrixArray[1]=this._matrixArray[1].minus(m._matrixArray[1]); result._matrixArray[2]=this._matrixArray[2].minus(m._matrixArray[2]); result._matrixArray[3]=this._matrixArray[3].minus(m._matrixArray[3]);; } returnresult; } protectedMatrixP1(Matrixm){ return_matrixArray[0].multi(m._matrixArray[1]).minus(_matrixArray[0].multi(m._matrixArray[3])); } protectedMatrixP2(Matrixm){ return_matrixArray[0].multi(m._matrixArray[3]).add(_matrixArray[1].multi(m._matrixArray[3])); } protectedMatrixP3(Matrixm){ return_matrixArray[2].multi(m._matrixArray[0]).add(_matrixArray[3].multi(m._matrixArray[0])); } protectedMatrixP4(Matrixm){ return_matrixArray[3].multi(m._matrixArray[2]).minus(_matrixArray[3].multi(m._matrixArray[0])); } protectedMatrixP5(Matrixm){ return(_matrixArray[0].add(_matrixArray[3])).multi(m._matrixArray[0].add(m._matrixArray[3])); } protectedMatrixP6(Matrixm){ return(_matrixArray[1].minus(_matrixArray[3])).multi(m._matrixArray[2].add(m._matrixArray[3])); } protectedMatrixP7(Matrixm){ return(_matrixArray[0].minus(_matrixArray[2])).multi(m._matrixArray[0].add(m._matrixArray[1])); } publicintget(inti,intj){ if(n==1){ returnelement; }else{ intsize=n/2; returnthis._matrixArray[(i/size)*2+(j/size)].get(i%size,j%size); } } publicvoiddisplay(){ for(inti=0;i总结
到此这篇关于使用java写的矩阵乘法的文章就介绍到这了,更多相关java矩阵乘法(Strassen算法)内容请搜索毛票票以前的文章或继续浏览下面的相关文章希望大家以后多多支持毛票票!