稀疏矩阵乘法问题!
程序代码:
#define MAX 9
typedef struct {
int row;
int col;
int value;
}term;
term a[MAX]={{6,6,8},{0,0,15},{0,3,22},{0,5,-15},{1,1,11},{1,2,3},{2,3,-6},{4,0,91},{5,2,28}};
term b[MAX]={{6,6,8},{1,0,34},{3,2,15},{1,5,10},{1,2,4},{4,2,3},{5,3,-6},{4,0,41},{5,2,30}};
void mmult(term a[],term b[],term d[])//d是相乘后的矩阵
{
int i,j,column,totalb=b[0].value,totald=0;
int rows_a=a[0].row,cols_a=a[0].col,
totala=a[0].value;
//printf("%d",totala);
int cols_b=b[0].col;
int row_begin=1,row=a[1].row,sum=0;
term new_b[MAX];//书上是这样写的:int new_b[MAX][3],这不是明显的错误吗?
if(cols_a!=b[0].row)
{
fprintf(stderr,"Incompatible matrices\n");
exit(1);
}
transpose(b,new_b);
/*prints_matrix_test(b);
printf("\n");
prints_matrix_test(new_b);
print(new_b);*/ //这一段只是测试转置是否成功
a[totala+1].row=rows_a;//
new_b[totalb+1].row=cols_b;//
new_b[totalb+1].col=0;//很奇怪这三段代码不会报错?这三段代码有什么用处?看不出~~!
for(i=1;i<=totala;)//下面的注释都是自己写上去的,貌似思想是对的,可是结果不对!
{
column=new_b[1].row;//获取b的列
for(j=1;j<=totalb+1;)// 遍历列
{
if(a[i].row!=row)//如果遍历到的行与当前所在行不一致
{
storesum(d,&totald,row,column,&sum);//将当前行列元素写入d中
i=row_begin;//重置当前行
for(;new_b[j].row==column;j++)//遍历至下一列
{
;
}
column=new_b[j].row;//将下一列作为当前列
}
else if(new_b[j].row!=column)//如果遍历到的列与当前所在列不一致
{
storesum(d,&totald,row,column,&sum);//将当前行列元素写入d中
i=row_begin;//重置当前行
column=new_b[j].row;//重置当前列
}
else
{
switch(COMPARE(a[i].col,new_b[j].col))//比较a当前列的下标与b当前行的下标的大小
{
case -1://列下标比行下标小,列下标移至下一列
i++;break;
case 0:
sum+=(a[i++].value*new_b[j++].value);
break;
case 1://行下标比列下标小,行下标移至下一行
j++;break;
}
}
}
for(;a[i].row==row;i++)//a矩阵跳到下一行
{
;
}
row_begin=i;
row=a[i].row;
}
d[0].row=rows_a;
d[0].col=cols_b;
d[0].value=totald;
}
void storesum(term d[],int *totald,int row,int column,int *sum)
{
if(*sum)
{
if((*totald)<MAX)
{
d[++*totald].row=row;
d[*totald].col=column;
d[*totald].value=*sum;
}
else
{
fprintf(stderr,"Numbers of terms in product exceeds %d\n",MAX);
exit(1);
}
}
}
int COMPARE(int a,int b)
{
if(a<b)
return -1;
else if(a==b)
return 0;
else
return 1;
}
下面附上测试程序:
程序代码:void transpose(term a[],term b[])//转置函数
{
int row_terms[7],startingpos[7];
int i,j,num_cols=a[0].col,num_terms=a[0].value;
b[0].row=num_cols;b[0].col=a[0].row;
b[0].value=num_terms;
if(num_terms>0)
{
for(i=0;i<num_cols;i++)
{
row_terms[i]=0;//将行中元素个数置为0
}
for(i=1;i<=num_terms;i++)
{
row_terms[a[i].col]++;//记录行中非零元素的个数
}
startingpos[0]=1;
for(i=1;i<num_cols;i++)
{
startingpos[i]=startingpos[i-1]+row_terms[i-1];
}
for(i=1;i<=num_terms;i++)
{
j=startingpos[a[i].col]++;
b[j].row=a[i].col;
b[j].col=a[i].row;
b[j].value=a[i].value;
}
printf("\n");
}
}
void prints_matrix_test(term *arry)//打印矩阵函数
{
int i,j;
int k=1;
for(i=0;i<(*arry).row;++i)
{
for(j=0;j<(*arry).col;++j)
{
while(k<MAX)//通过循环遍历到符合位置的元素并打印
{
if(i==(*(arry+k)).row&&j==(*(arry+k)).col)
{
printf("%4d",(*(arry+k)).value);
break;
}
++k;
}
if(k==MAX)
{
printf("%4d",0);
}
k=1;//k重置,为下一次循环做准备
}
printf("\n");
}
}
int print(term arry[])//打印三元组函数
{
int i;
for(i=0;i<MAX;++i)
{
printf("%d %d %d \n",arry[i].row,arry[i].col,arry[i].value);
}
return 0;
}




