SoftMax硬件设计

  1. 1. SoftMax
    1. 1.1. SoftMax量化
    2. 1.2. SoftMax硬件实现
      1. 1.2.1. SoftMax_FindMax.sv
    3. 1.3. SoftMax硬件架构以及位宽设置

SoftMax

首先是SoftMax的原始公式

在实际计算时,所有的输入数据需要减去一个最大值防止溢出(减去最大值这个操作对SoftMax的计算没有任何影响),于是SoftMax变成如下公式:

SoftMax量化

SoftMax的量化参考2022年的论文Fq-Vit,虽然是第一篇专门针对Vit进行量化的论文,但是这篇论文基本上是将I-Bert移植到Fq-Vit上了,而且没有对Gelu进行量化,所以我只能用Relu替代Gelu~

SoftMax硬件实现

SoftMax_FindMax.sv

首先需要明确SoftMax的输入是一个按行优先排列的矩阵,至少在Vit中的注意力机制计算时,输入的是一个197*197大小的方阵。由于采用的数据排列策略是行优先,所以需要先遍历一行的输入,找到这一行数据的最大值,然后再重新用这一行的数据减去最大值。
因此,首先需要设计一条二级流水线,寻找最大值以及对输入数据进行缓存。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
module Stage1_FindMax#(parameter BRAM_DEPTH =197,parameter DATA_WIDTH=64,parameter MAX_ROW=1024)(//现在一下进8个点,找最大值应该是按照8个点流水式地找
//Bram的深度:200

input clk,
input rstn,
input sValid,
input [DATA_WIDTH-1:0]sData,//这里输入的Xq是Xq-Zp后的值
input wire sLast,
output logic sReady,
input [$clog2(BRAM_DEPTH)-1:0]Matrix_Col,//矩阵列数,若列数不是8的倍数,上层模块应该进行补零操作
input [$clog2(MAX_ROW)-1:0]Matrix_Row,//矩阵列数,若列数不是8的倍数,上层模块应该进行补零操作
output logic [DATA_WIDTH-1:0]mData,//最大值减去Xq
output logic mValid,
output logic mLast
);
genvar i;
// for (i=0;i<8;i=i+1)begin
// wire [7:0]sData_Test;
// assign sData_Test=sData[i*8+:8];
// end
// 向上取整
logic [$clog2(BRAM_DEPTH)-1:0]count_times;
logic remainder;
assign remainder=|Matrix_Col[$clog2(DATA_WIDTH)-1:0];
assign count_times=(Matrix_Col>>$clog2(DATA_WIDTH/8))+remainder;//todo:下面这种表示写法有问题,出来的结果是12,,运算符优先级的问题吧
// assign count_times=Matrix_Col>>$clog2(DATA_WIDTH/8)+remainder;

reg [$clog2(MAX_ROW)-1:0]Row_Cnt;//行计数器
logic [$clog2(BRAM_DEPTH)-1:0]Col_Cnt;//列计数器
logic signed[7:0]Max;
logic signed[7:0]Last_Max;
logic Col_Cnt_Valid;
assign Col_Cnt_Valid =(Col_Cnt==count_times-1&&sValid)||Col_Cnt==count_times-1&&Row_Cnt==Matrix_Row;//后面这一个条件是为了处理边界问题


always_ff@(posedge clk or negedge rstn) begin
if(!rstn)begin
Col_Cnt<=0;
end
else if(Col_Cnt_Valid)begin
Col_Cnt<=0;
end
else if(sValid&sReady||(Row_Cnt==Matrix_Row)&sReady)Col_Cnt<=Col_Cnt+1'b1;//最后一行的时候也要读取数据
end

//构建一个多级流水的比较器

reg signed [7:0]Max_4[0:3];//对进来的8个点进行两两比较,取4个最大值

always_ff@(posedge clk or negedge rstn)begin
if (!rstn)
for (int i=0;i<4;i=i+1)
Max_4[i]<=0;
else begin

for (int i = 0; i <4; i = i+1)//需要注意这里不加signed的话,比较就按无符号比较,于是出现负数比正数大的情况
Max_4[i] <= $signed(sData[2*i*8+:8])<$signed(sData[(2*i+1)*8+:8])?sData[(2*i+1)*8+:8]:sData[2*i*8+:8];//>sData[7:0]?sData[7:0]:sData[7:0];//对比较结果打一拍
end
end
// wire[7:0]testttttt[0:3];
// for(i=0;i<4;i=i+1)assign testttttt[i]=sData[(2*i)*8+:8];
reg signed [7:0]Max_2[0:1];//对上一级的4个点继续进行两两比较,取2个最大值

always_ff@(posedge clk or negedge rstn)begin
if (!rstn)
// Max_2 <= '{default: '0};
for (int i=0;i<2;i=i+1)
Max_2[i]<=0;
else begin
Max_2[0] <= Max_4[0]<Max_4[1]?Max_4[1]:Max_4[0];
Max_2[1] <= Max_4[2]<Max_4[3]?Max_4[3]:Max_4[2];
end
end

logic signed [7:0]Max_1;
wire Max1_Valid;
always_ff@(posedge clk or negedge rstn)begin
if(!rstn)begin
Max_1<='d0;
end
else begin
Max_1<=Max_2[1]<Max_2[0]?Max_2[0]:Max_2[1];
end
end



parameter Max_Dly_Times=4;//对Max_valid 延时4拍
parameter Max1_Dly_Times=3;
reg [Max_Dly_Times-1:0]Max_Valid_Dly;
reg [Max1_Dly_Times-1:0]Max1_Valid_Dly;
assign Max1_Valid=Max1_Valid_Dly[Max1_Dly_Times-1];
wire Max_Valid;
assign Max_Valid=Max_Valid_Dly[Max_Dly_Times-1];

always_ff@(posedge clk or negedge rstn)begin
if(!rstn)begin
Max<='d0;
end
else if(Max_Valid)Max<=0;
else begin
Max<=(Max<Max_1)&&Max1_Valid?Max_1:Max;
end
end
always_ff@(posedge clk or negedge rstn)begin
if ((!rstn)|mLast)begin
Max_Valid_Dly<='d0;
Max1_Valid_Dly<='d0;
// Max_Valid<=1'b0;
end
else begin
Max_Valid_Dly<={Max_Valid_Dly[Max_Dly_Times-2:0],Col_Cnt_Valid};
Max1_Valid_Dly<={Max1_Valid_Dly[Max1_Dly_Times-2:0],sValid};
// Max_Valid<=Max_Valid_Dly[Max_Dly_Times-1];//这里不知道怎么回事,赋值位Max_Dly[3],VCS也没有报错。。
end
end

always_ff@(posedge clk or negedge rstn) begin
if(!rstn)begin
Last_Max<='b0;
end
else if(Max_Valid)begin
Last_Max<=Max;
end
end


logic IS_GET_ONE_COL;
always_ff @(posedge clk or negedge rstn) begin
if((!rstn)|mLast)begin
IS_GET_ONE_COL<=1'b0;
end
else if(Max_Valid)begin
IS_GET_ONE_COL<=1'b1;
end
end



always_ff @(posedge clk or negedge rstn)begin
if(!rstn)begin
sReady<=1'b1;
end

else if (Col_Cnt_Valid)sReady<=1'b0;//这样做的目的是让进来的数停一下,因为比较器中还有一些残留的数据没处理完
else if (Max_Valid|mLast)sReady<=1'b1;
end
//同时加一个缓存模块
//当进完一行数据后,之后就需要从MEM中取数据了
logic [DATA_WIDTH-1:0]Stage1_Mem[0:BRAM_DEPTH-1];
//每进一个数,就把它存下来

reg [$clog2(BRAM_DEPTH)-1:0]waddr;//写地址,写地址要慢都地址一拍,防止读写冲突
reg [DATA_WIDTH-1:0]wdata;
reg wen;//写使能
always_ff@(posedge clk or negedge rstn)begin
if (!rstn)begin
waddr<=0;
wdata<=0;
wen<=0;
end
else begin
waddr<=Col_Cnt;//Col_Cnt其实也是读地址,写地址永远慢读地址一拍
wdata<=sData;
wen<=sValid&&sReady;
end
end

always_ff @(posedge clk) begin
if(wen)begin
Stage1_Mem[waddr]<=wdata;//进一个数,存下来
end
end

//接下来的策略就是读快写一拍,每进一个数,先读再写,有Col_Cnt控制读写地址

for (i=0;i<DATA_WIDTH/8;i=i+1)begin
always_ff@(posedge clk or negedge rstn)begin
if (!rstn)mData[i*8+:8]<=0;//Last_Max-
else mData[i*8+:8]<=$signed(Stage1_Mem[Col_Cnt][i*8+:8])-$signed(Last_Max);//-Last_Max;
end
end

wire Row_Cnt_Valid;
assign Row_Cnt_Valid=(Row_Cnt==Matrix_Row)&&Col_Cnt_Valid;//Row_Cnt需要多数一行
always_ff@(posedge clk or negedge rstn)begin
if(!rstn)Row_Cnt<=0;
else if(Row_Cnt_Valid)Row_Cnt<=0;
else if(Col_Cnt_Valid)Row_Cnt<=Row_Cnt+1'b1;
end

// logic valid_next;
//还是把这里改成fifo好了
always_ff @(posedge clk or negedge rstn) begin
if(!rstn)begin
mValid<=0;
mLast<=0;
end
else begin
mLast<=Row_Cnt_Valid;
mValid<=(IS_GET_ONE_COL&sValid&sReady)||(Row_Cnt==Matrix_Row&sReady);//进一个数,往外面蹦一个数
end

end
endmodule

SoftMax硬件架构以及位宽设置