drbh (HF Staff) committed
Commit 4a05275 · verified · 1 parent: c623dea

Update benchmarks/benchmark.py

Files changed (1)
  1. benchmarks/benchmark.py +104 -104
benchmarks/benchmark.py CHANGED
@@ -216,107 +216,107 @@ class FlashMLACausalBenchmark(Benchmark):
         return _verify_mla_decode(self, causal=True)
 
 
-class FlashMLAVarlenBenchmark(Benchmark):
-    seed: int = 42
-
-    # Workload: small (3 sequences, max_seqlen=64)
-    def setup_small(self):
-        H, D = 8, 64
-        seqlens = [32, 48, 64]
-        total = sum(seqlens)
-        self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.cu_seqlens = torch.tensor(
-            [0] + list(torch.cumsum(torch.tensor(seqlens), 0)),
-            device="cuda",
-            dtype=torch.int32,
-        )
-        self.max_seqlen = max(seqlens)
-        self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)
-
-    def benchmark_small(self):
-        self.out = _extract_output(
-            self.kernel.flash_attn_varlen_func(
-                self.q,
-                self.k,
-                self.v,
-                self.cu_seqlens,
-                self.cu_seqlens,
-                self.max_seqlen,
-                self.max_seqlen,
-            )
-        )
-
-    def verify_small(self) -> torch.Tensor:
-        return _varlen_reference_attention(
-            self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
-        )
-
-    # Workload: medium (5 sequences, max_seqlen=256)
-    def setup_medium(self):
-        H, D = 16, 64
-        seqlens = [128, 192, 256, 200, 150]
-        total = sum(seqlens)
-        self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.cu_seqlens = torch.tensor(
-            [0] + list(torch.cumsum(torch.tensor(seqlens), 0)),
-            device="cuda",
-            dtype=torch.int32,
-        )
-        self.max_seqlen = max(seqlens)
-        self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)
-
-    def benchmark_medium(self):
-        self.out = _extract_output(
-            self.kernel.flash_attn_varlen_func(
-                self.q,
-                self.k,
-                self.v,
-                self.cu_seqlens,
-                self.cu_seqlens,
-                self.max_seqlen,
-                self.max_seqlen,
-            )
-        )
-
-    def verify_medium(self) -> torch.Tensor:
-        return _varlen_reference_attention(
-            self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
-        )
-
-    # Workload: large (8 sequences, max_seqlen=512)
-    def setup_large(self):
-        H, D = 32, 128
-        seqlens = [256, 384, 512, 448, 320, 480, 400, 512]
-        total = sum(seqlens)
-        self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
-        self.cu_seqlens = torch.tensor(
-            [0] + list(torch.cumsum(torch.tensor(seqlens), 0)),
-            device="cuda",
-            dtype=torch.int32,
-        )
-        self.max_seqlen = max(seqlens)
-        self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)
-
-    def benchmark_large(self):
-        self.out = _extract_output(
-            self.kernel.flash_attn_varlen_func(
-                self.q,
-                self.k,
-                self.v,
-                self.cu_seqlens,
-                self.cu_seqlens,
-                self.max_seqlen,
-                self.max_seqlen,
-            )
-        )
-
-    def verify_large(self) -> torch.Tensor:
-        return _varlen_reference_attention(
-            self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
-        )
+# class FlashMLAVarlenBenchmark(Benchmark):
+#     seed: int = 42
+
+#     # Workload: small (3 sequences, max_seqlen=64)
+#     def setup_small(self):
+#         H, D = 8, 64
+#         seqlens = [32, 48, 64]
+#         total = sum(seqlens)
+#         self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.cu_seqlens = torch.tensor(
+#             [0] + list(torch.cumsum(torch.tensor(seqlens), 0)),
+#             device="cuda",
+#             dtype=torch.int32,
+#         )
+#         self.max_seqlen = max(seqlens)
+#         self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)
+
+#     def benchmark_small(self):
+#         self.out = _extract_output(
+#             self.kernel.flash_attn_varlen_func(
+#                 self.q,
+#                 self.k,
+#                 self.v,
+#                 self.cu_seqlens,
+#                 self.cu_seqlens,
+#                 self.max_seqlen,
+#                 self.max_seqlen,
+#             )
+#         )
+
+#     def verify_small(self) -> torch.Tensor:
+#         return _varlen_reference_attention(
+#             self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
+#         )
+
+#     # Workload: medium (5 sequences, max_seqlen=256)
+#     def setup_medium(self):
+#         H, D = 16, 64
+#         seqlens = [128, 192, 256, 200, 150]
+#         total = sum(seqlens)
+#         self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.cu_seqlens = torch.tensor(
+#             [0] + list(torch.cumsum(torch.tensor(seqlens), 0)),
+#             device="cuda",
+#             dtype=torch.int32,
+#         )
+#         self.max_seqlen = max(seqlens)
+#         self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)

+#     def benchmark_medium(self):
+#         self.out = _extract_output(
+#             self.kernel.flash_attn_varlen_func(
+#                 self.q,
+#                 self.k,
+#                 self.v,
+#                 self.cu_seqlens,
+#                 self.cu_seqlens,
+#                 self.max_seqlen,
+#                 self.max_seqlen,
+#             )
+#         )

+#     def verify_medium(self) -> torch.Tensor:
+#         return _varlen_reference_attention(
+#             self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
+#         )

+#     # Workload: large (8 sequences, max_seqlen=512)
+#     def setup_large(self):
+#         H, D = 32, 128
+#         seqlens = [256, 384, 512, 448, 320, 480, 400, 512]
+#         total = sum(seqlens)
+#         self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
+#         self.cu_seqlens = torch.tensor(
+#             [0] + list(torch.cumsum(torch.tensor(seqlens), 0)),
+#             device="cuda",
+#             dtype=torch.int32,
+#         )
+#         self.max_seqlen = max(seqlens)
+#         self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16)

+#     def benchmark_large(self):
+#         self.out = _extract_output(
+#             self.kernel.flash_attn_varlen_func(
+#                 self.q,
+#                 self.k,
+#                 self.v,
+#                 self.cu_seqlens,
+#                 self.cu_seqlens,
+#                 self.max_seqlen,
+#                 self.max_seqlen,
+#             )
+#         )

+#     def verify_large(self) -> torch.Tensor:
+#         return _varlen_reference_attention(
+#             self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
+#         )
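
Aside (not part of the commit): the varlen API above packs all sequences into one tensor along dim 0 and indexes it with cu_seqlens, i.e. cumulative sequence lengths with a leading zero, which is exactly what the commented-out setup builds. A minimal standalone sketch of that layout, with illustrative shapes borrowed from the "small" workload:

import torch

# Sketch of the packed varlen layout (illustrative; CPU tensors for simplicity).
seqlens = [32, 48, 64]        # per-sequence lengths from the "small" workload
total = sum(seqlens)          # 144 packed rows
H, D = 8, 64

# All sequences concatenated along dim 0: shape (total, H, D).
q = torch.randn(total, H, D)

# Cumulative lengths with a leading zero -> tensor([0, 32, 80, 144]).
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(torch.tensor(seqlens, dtype=torch.int32), dim=0)

# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor.
for i in range(len(seqlens)):
    seq = q[cu_seqlens[i] : cu_seqlens[i + 1]]
    assert seq.shape[0] == seqlens[i]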