Dynamic Programming – 5 : Matrix chain multiplication
February 3, 2011
Matrix chain multiplication
We are given a sequence of $n$ matrices $A_1, A_2, \ldots, A_n$, such that matrix $A_i$ has dimensions $p_{i-1} \times p_i$. Our task is to compute an optimal parenthesization of this sequence of matrices such that the sequence is fully parenthesized.
Terminology: A sequence of matrices is fully parenthesized if the sequence is either a single matrix or a product of 2 fully parenthesized sequences.
Optimality here refers to minimization of the number of scalar multiplications involved in computing the product of the matrix sequence.
// C code for multiplying A (m x n) with B (n x p). Product is C (m x p).
for (i = 0; i < m; i++) {
    for (j = 0; j < p; j++) {
        C[i][j] = 0;
        for (k = 0; k < n; k++)
            C[i][j] += A[i][k] * B[k][j];  // one scalar multiplication per (i, j, k)
    }
}
// End of code.
Note 2: From the above, we can see that multiplying an (m x n) matrix with an (n x p) matrix requires m · n · p scalar multiplications, one for each combination of i, j and k.
Note 3: Matrix multiplication is associative. So, we can parenthesize the given sequence of matrices whichever way we want as long as we do not disturb the sequence itself.
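To see why the choice of parenthesization matters, here is a minimal sketch in C that compares the two ways of multiplying three matrices. The dimensions (10 x 100, 100 x 5, 5 x 50) and the function names are made up for illustration.

```c
#include <assert.h>

/* Scalar multiplications needed to multiply an (m x n) matrix by an (n x p) matrix. */
long mult_cost(long m, long n, long p) {
    assert(m > 0 && n > 0 && p > 0);
    return m * n * p;
}

/* Hypothetical example: A1 is 10 x 100, A2 is 100 x 5, A3 is 5 x 50. */

/* Cost of (A1 A2) A3: first A1 A2 (a 10 x 5 result), then that result times A3. */
long cost_left_first(void) {
    return mult_cost(10, 100, 5) + mult_cost(10, 5, 50);
}

/* Cost of A1 (A2 A3): first A2 A3 (a 100 x 50 result), then A1 times that result. */
long cost_right_first(void) {
    return mult_cost(100, 5, 50) + mult_cost(10, 100, 50);
}
```

Here cost_left_first() evaluates to 7500 and cost_right_first() to 75000, so the same product can differ in cost by a factor of ten depending on where the parentheses go.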
If n > 1, we know that the required parenthesization divides the matrix sequence into 2 smaller fully parenthesized sequences.
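Let this division occur at matrix $A_k$, where $1 \le k < n$, i.e. the number of scalar multiplications in $A_1 \cdots A_n$ equals the sum of the number of scalar multiplications in $A_1 \cdots A_k$ and $A_{k+1} \cdots A_n$ plus the value $p_0 p_k p_n$ (recall that $A_i$ is $p_{i-1} \times p_i$, so this is the cost of multiplying the resulting $p_0 \times p_k$ and $p_k \times p_n$ matrices).
Notation: Let $OPT[i,j]$ denote the number of scalar multiplications in the optimal parenthesization of the sequence $A_i, A_{i+1}, \ldots, A_j$.
From the above, since both halves must themselves be optimally parenthesized and $k$ must be the best place to split, we have that:
$OPT[1,n] = \min_{1 \le k < n} \{ OPT[1,k] + OPT[k+1,n] + p_0 p_k p_n \}$
Generalizing this, we can say that, for $1 \le i < j \le n$:
$OPT[i,j] = \min_{i \le k < j} \{ OPT[i,k] + OPT[k+1,j] + p_{i-1} p_k p_j \}$
Further, $OPT[i,i] = 0$ for all possible $i$, since a single matrix requires no multiplication.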
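As a small worked example of the recurrence, take three matrices with the made-up dimensions 10 x 100, 100 x 5 and 5 x 50. The notation here is an assumption: $OPT[i,j]$ is the minimum cost for the matrices numbered $i$ through $j$, matrix $i$ has dimensions $p_{i-1} \times p_i$, so $p_0 = 10$, $p_1 = 100$, $p_2 = 5$, $p_3 = 50$, and a single matrix costs 0.

```latex
\begin{align*}
OPT[1,2] &= p_0 p_1 p_2 = 10 \cdot 100 \cdot 5 = 5000 \\
OPT[2,3] &= p_1 p_2 p_3 = 100 \cdot 5 \cdot 50 = 25000 \\
OPT[1,3] &= \min \{\, OPT[1,1] + OPT[2,3] + p_0 p_1 p_3,\;\; OPT[1,2] + OPT[3,3] + p_0 p_2 p_3 \,\} \\
         &= \min \{\, 0 + 25000 + 50000,\;\; 5000 + 0 + 2500 \,\} = 7500
\end{align*}
```

The minimum is attained at $k = 2$, i.e. by splitting after the second matrix.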
Pseudocode for computing optimum number of scalar multiplications:
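For i from 1 to n:
    Set OPT[i,i] = 0.
For len from 2 to n: // len is the number of matrices in the subsequence
    For i from 1 to (n-len+1):
        Set j = i + len - 1.
        Compute OPT[i,j] according to the above recurrence.
        Set Soln[i,j] = k; // k is the index that optimizes OPT[i,j] in the recurrence.
End of algorithm.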
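The above code will take $O(n^3)$ time to find $OPT[1,n]$ (in the process, computing the OPT[i,j] and Soln[i,j] table values): there are $O(n^2)$ table entries, and each one takes $O(n)$ time to minimize over k.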
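The table computation can be sketched in C roughly as follows. This is one possible implementation, not necessarily the only way to order the loops: it fills the entries in order of increasing subsequence length, so that every shorter subsequence is already solved when it is needed. The names matrix_chain_order and MAX_N, and the convention that the array p holds the n + 1 dimensions (matrix i being p[i-1] x p[i]), are assumptions made for this sketch.

```c
#include <assert.h>
#include <limits.h>

#define MAX_N 64  /* assumed upper bound on the number of matrices in this sketch */

long OPT[MAX_N + 1][MAX_N + 1];   /* OPT[i][j]: minimum cost for matrices i..j */
int  Soln[MAX_N + 1][MAX_N + 1];  /* Soln[i][j]: split index k that attains OPT[i][j] */

/* p has n + 1 entries; matrix i is p[i-1] x p[i] for i = 1..n.
   Fills OPT and Soln and returns OPT[1][n]. */
long matrix_chain_order(const long p[], int n) {
    assert(n >= 1 && n <= MAX_N);
    for (int i = 1; i <= n; i++)
        OPT[i][i] = 0;                       /* a single matrix needs no multiplication */
    for (int len = 2; len <= n; len++) {     /* len = number of matrices in the subsequence */
        for (int i = 1; i + len - 1 <= n; i++) {
            int j = i + len - 1;
            OPT[i][j] = LONG_MAX;
            for (int k = i; k < j; k++) {    /* try every split (i..k)(k+1..j) */
                long cost = OPT[i][k] + OPT[k + 1][j] + p[i - 1] * p[k] * p[j];
                if (cost < OPT[i][j]) {
                    OPT[i][j] = cost;
                    Soln[i][j] = k;
                }
            }
        }
    }
    return OPT[1][n];
}
```

For example, with p = {10, 100, 5, 50} and n = 3, the function returns 7500 and leaves Soln[1][3] = 2. The three nested loops are where the cubic running time comes from.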
Computing the optimal solution i.e. parenthesization
We can do so using the Soln [i,j] table.
Function: Print_Solution (i, j)
If ( i == j)
    Print "A" followed by i;
Else
    Print "(";
    Print_Solution (i, Soln[i,j] );
    Print_Solution ( Soln[i,j] + 1, j);
    Print ")";
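This function takes $O(n)$ time for the call Print_Solution (1, n), since it is invoked once for each of the n matrices and once for each of the n - 1 pairs of parentheses.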
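A minimal C sketch of this reconstruction is given below. To keep it self-contained, the Soln table is filled by hand for a made-up three-matrix example (10 x 100, 100 x 5, 5 x 50), where the best split of the whole sequence is after the second matrix. The names write_solution and fill_example_soln are hypothetical, and the sketch writes into a caller-supplied buffer (assumed to be large enough) instead of printing directly.

```c
#include <assert.h>
#include <stdio.h>

#define MAX_N 8   /* assumed small bound for this sketch */

int Soln[MAX_N + 1][MAX_N + 1];  /* Soln[i][j]: optimal split index for matrices i..j */

/* Hand-filled table for the hypothetical example A1 (10 x 100), A2 (100 x 5), A3 (5 x 50). */
void fill_example_soln(void) {
    Soln[1][2] = 1;  /* A1..A2 can only split after A1 */
    Soln[2][3] = 2;  /* A2..A3 can only split after A2 */
    Soln[1][3] = 2;  /* best split of A1..A3 is (A1 A2)(A3) */
}

/* Writes the parenthesization of matrices i..j at out and returns a pointer
   to the terminating NUL. The caller must supply a large enough buffer. */
char *write_solution(int i, int j, char *out) {
    assert(i <= j);
    if (i == j)
        return out + sprintf(out, "A%d", i);        /* a single matrix: just its name */
    *out++ = '(';
    out = write_solution(i, Soln[i][j], out);       /* left half:  i .. k   */
    out = write_solution(Soln[i][j] + 1, j, out);   /* right half: k+1 .. j */
    *out++ = ')';
    *out = '\0';
    return out;
}
```

After calling fill_example_soln(), write_solution(1, 3, buf) leaves the string ((A1A2)A3) in buf. Returning the end pointer avoids rescanning the buffer, so the number of calls stays linear in n, matching the pseudocode above.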