矩阵乘法的Strassen算法

  最简单的矩阵乘法可以通过三重循环来实现,其时间复杂度为Θ(n3)\Theta(n^{3}),Strassen算法通过巧妙的增加加法来减少乘法实现了O(n2.81)O(n^{2.81})的时间复杂度

Strassen算法的四个步骤:

  1. 将输入矩阵A、B与输出矩阵C分解为n/2×n/2n/2\times n/2的子矩阵,采用下标计算方法,此步骤花费Θ\Theta(1)时间。
  2. 创建10个n/2×n/2n/2\times n/2的矩阵,每个矩阵保存步骤1中创建的两个子矩阵的和或差,花费Θ(n2)\Theta(n^2)
  3. 用步骤1中创建的子矩阵和步骤2中创建的10个矩阵,递归的计算7个PiP_i矩阵积。
  4. 通过PiP_i矩阵的不同组合进行加减运算,计算出C的子矩阵,花费时间Θ(n2)\Theta(n^2)

  为了方便计算矩阵积C=A\cdotB,假定三个矩阵均为n×nn\times n矩阵,其中n为2的幂。做出这个假设是因为在每个分解步骤中,n×nn\times n矩阵都被划分为4个n/2×n/2n/2\times n/2的子矩阵,如果假定nn是2的幂,则只要n2n\geq 2即可保证子矩阵规模n/2n/2为整数。

假定将A、B和C均分解为4个n/2×n/2n/2\times n/2的子矩阵:

A=[A11A12A21A22]B=[B11B12B21B22]C=[C11C12C21C22]A= \left[ \begin{matrix} A_{11} & A_{12} \\ A_{21} & A_{22} \\ \end{matrix} \right] ,B= \left[ \begin{matrix} B_{11} & B_{12} \\ B_{21} & B_{22} \\ \end{matrix} \right] ,C= \left[ \begin{matrix} C_{11} & C_{12} \\ C_{21} & C_{22} \\ \end{matrix} \right]

根据矩阵乘法的定义,可以得到如下4个公式:

C11=A11B11+A12B21C12=A11B12+A12B22C21=A21B11+A22B21C22=A21B12+A22B22(1)\begin{aligned} C_{11}&=A_{11}\cdot B_{11}+A_{12}\cdot B_{21}\\ C_{12}&=A_{11}\cdot B_{12}+A_{12}\cdot B_{22}\\ C_{21}&=A_{21}\cdot B_{11}+A_{22}\cdot B_{21}\\ C_{22}&=A_{21}\cdot B_{12}+A_{22}\cdot B_{22}\\ \end{aligned}\left( 1 \right)

步骤2中,创建如下10个矩阵:

S1=B12B22S2=A11A12S3=A21+A22S4=B21B21S5=A11+A22S6=B11+B22S7=A12A22S8=B21+B22S9=A11A21S10=B11+B12(2)\begin{aligned} S_1&=B_{12}-B_{22}\\ S_2&=A_{11}-A_{12}\\ S_3&=A_{21}+A_{22}\\ S_4&=B_{21}-B_{21}\\ S_5&=A_{11}+A_{22}\\ S_6&=B_{11}+B_{22}\\ S_7&=A_{12}-A_{22}\\ S_8&=B_{21}+B_{22}\\ S_9&=A_{11}-A_{21}\\ S_{10}&=B_{11}+B_{12}\\ \end{aligned}\left( 2 \right)

由于必须进行10次n/2×n/2n/2\times n/2矩阵的加减法,因此,该步骤花费Θ(n2)\Theta(n^2)时间。
步骤3中,递归的计算7次n/2×n/2n/2\times n/2矩阵的乘法,如下所示:

P1=A11S1P2=S2B22P3=S3B11P4=A22S4P5=S5S6P6=S7S8P7=S9S10\begin{aligned} P_1&=A_{11}\cdot S_1\\ P_2&=S_2\cdot B_{22}\\ P_3&=S_3\cdot B_{11}\\ P_4&=A_{22}\cdot S_4\\ P_5&=S_5\cdot S_6\\ P_6&=S_7\cdot S_8\\ P_7&=S_9\cdot S_{10}\\ \end{aligned}

步骤4中,

C11=P5+P4P2+P6C12=P1+P2C21=P3+P4C22=P5+P1P3P7\begin{aligned} C_{11}&=P_5+P_4-P_2+P_6\\ C_{12}&=P_1+P_2\\ C_{21}&=P_3+P_4\\ C_{22}&=P_5+P_1-P_3-P_7\\ \end{aligned}

共进行了8次n/2×n/2n/2\times n/2矩阵的加减法,因此花费Θ(n2)\Theta(n^2)时间。
代值计算后可以发现(2)式结果与(1)式是相同的。

描述Strassen算法运行时间T(n)的递归式:

T(n)={Θ(1)n=17T(n/2)+Θ(n2)n>1T(n)=\begin{cases} \Theta(1)&n=1\\ 7T(n/2)+\Theta(n^2)&n>1\\ \end{cases}

  用主方法来求解这个递归式,可知解为T(n)=Θ(nlg7)T\left(n\right)=\Theta(n^{lg7}),由于lg7lg7介于2.80和2.81之间,所以时间复杂度为O(n2.81)O(n^{2.81})

天知道Strassen是怎么想到这个方法的QAQ