2023-03-29 06:47:08 +00:00
import matplotlib . pyplot as plt
import pandas as pd
import numpy as np
import os
import matplotlib . gridspec as gridspec
cmap = plt . get_cmap ( ' cool ' )
if __name__ == ' __main__ ' :
fig = plt . figure ( tight_layout = True , figsize = ( 12 , 3.5 ) )
gs = gridspec . GridSpec ( 1 , 2 )
2023-04-01 18:46:04 +00:00
dims_to_consider = [ 1024 , 1280 , 1408 , 1664 , 2048 , 4096 ]
batch_size_for_plot1 = 32768
batch_sizes_for_plot2 = [ 2 * * 14 , 2 * * 15 , 2 * * 16 , 2 * * 17 ]
dims_to_xtick = [ 1024 , 2048 , 4096 ]
logscale_plot1 = True
2023-03-29 06:47:08 +00:00
ax = fig . add_subplot ( gs [ 0 , 0 ] )
2023-04-01 23:50:12 +00:00
# TODO: change this to what you want.
2023-04-01 18:46:04 +00:00
rdf = pd . read_json ( ' speed_benchmark/info_a100_py2.jsonl ' , lines = True )
df = rdf [ rdf . batch_size == batch_size_for_plot1 ]
2023-03-29 06:47:08 +00:00
2023-04-01 18:46:04 +00:00
# first plot the time occupied by different operations
2023-03-29 06:47:08 +00:00
for k , marker , ls , color , name in [
( ' standard_gx+standard_gw+standard_fwd ' , ' s ' , ' - ' , ' C2 ' , ' Standard fp16 (sum of parts) ' ) ,
( ' x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd ' , ' o ' , ' - ' , ' C4 ' , ' SwitchBack int8 (sum of parts) ' ) ,
( ' standard_fwd ' , ' ^ ' , ' -- ' , ' C2 ' , ' Matmul XW (standard) ' ) ,
( ' standard_gw ' , ' ^ ' , ' -. ' , ' C2 ' , ' Matmul GW (standard) ' ) ,
( ' standard_gx ' , ' ^ ' , ' : ' , ' gray ' , ' Matmul GX (both) ' ) ,
( ' global_fwd ' , ' ^ ' , ' -- ' , ' C4 ' , ' Int8 Matmul XW (switchback) ' ) ,
( ' global_bwd ' , ' ^ ' , ' -. ' , ' C4 ' , ' Int8 Matmul GW (switchback) ' ) ,
( ' x_quantize_rowwise ' , ' P ' , ' -- ' , ' C4 ' , ' Quantize rowwise X (switchback) ' ) ,
( ' g_quantize_rowwise ' , ' P ' , ' -. ' , ' C4 ' , ' Quantize rowwise G (switchback) ' ) ,
( ' w_quantize_global ' , ' . ' , ' -- ' , ' C4 ' , ' Quatnize global W (switchback) ' ) ,
( ' w_quantize_global_transpose ' , ' . ' , ' -. ' , ' C4 ' , ' Quantize gloabl and \n transpose W (switchback) ' ) ,
] :
xs = [ ]
ys = [ ]
2023-04-01 18:46:04 +00:00
for embed_dim in dims_to_consider :
# average over dim -> 4*dim and 4*dim -> dim
2023-03-29 06:47:08 +00:00
df_ = df [ df . dim_in == embed_dim ]
df_ = df_ [ df_ . dim_out == embed_dim * 4 ]
xs . append ( embed_dim )
y_ = 0
for k_ in k . split ( ' + ' ) :
y_ + = df_ [ k_ ] . values [ 0 ]
df_ = df [ df . dim_in == embed_dim * 4 ]
df_ = df_ [ df_ . dim_out == embed_dim ]
for k_ in k . split ( ' + ' ) :
y_ + = df_ [ k_ ] . values [ 0 ]
ys . append ( y_ * 0.5 )
ax . plot ( xs , ys , color = color , label = name , marker = marker , markersize = 5 if marker == ' s ' else 5 , linestyle = ls , linewidth = 2 if ' + ' in k else 1. )
ax . set_xlabel ( ' dim ' , fontsize = 13 )
ax . set_ylabel ( ' time (ms) ' , fontsize = 13 )
ax . grid ( )
ax . set_xscale ( ' log ' )
2023-04-01 18:46:04 +00:00
if logscale_plot1 :
ax . set_yscale ( ' log ' )
2023-03-29 06:47:08 +00:00
ax . tick_params ( axis = ' x ' , labelsize = 11 )
ax . tick_params ( axis = ' y ' , labelsize = 11 )
2023-04-01 18:46:04 +00:00
ax . set_xticks ( dims_to_xtick )
ax . set_xticklabels ( dims_to_xtick )
2023-03-29 06:47:08 +00:00
ax . set_xticks ( [ ] , minor = True )
leg = ax . legend ( loc = ' upper center ' , bbox_to_anchor = ( - 0.64 , 1. ) , ncol = 1 , fontsize = 10 )
leg . get_texts ( ) [ 0 ] . set_fontweight ( ' bold ' )
leg . get_texts ( ) [ 1 ] . set_fontweight ( ' bold ' )
plt . subplots_adjust ( left = 0.1 )
ax . set_title ( ' Linear layer, batch * sequence length = 32k ' , fontsize = 10 , loc = ' left ' , y = 1.05 , pad = - 20 )
ax = fig . add_subplot ( gs [ 0 , 1 ] )
# now plot the % speedup for different batch sizes
2023-04-01 18:46:04 +00:00
for j , batch_size in enumerate ( batch_sizes_for_plot2 ) :
2023-03-29 06:47:08 +00:00
all_xs , all_ys = [ ] , [ ]
for k , marker , ls , color , name in [
( ' standard_gx+standard_gw+standard_fwd ' , ' s ' , ' - ' , ' C2 ' , ' Standard fp16 (total time) ' ) ,
( ' x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd ' , ' o ' , ' - ' , ' C4 ' , ' SwitchBack int8 (total time) ' ) ,
] :
xs , ys = [ ] , [ ]
df = rdf [ rdf . batch_size == batch_size ]
2023-04-01 18:46:04 +00:00
for embed_dim in dims_to_consider :
2023-03-29 06:47:08 +00:00
df_ = df [ df . dim_in == embed_dim ]
df_ = df_ [ df_ . dim_out == embed_dim * 4 ]
xs . append ( embed_dim )
y_ = 0
for k_ in k . split ( ' + ' ) :
y_ + = df_ [ k_ ] . values [ 0 ]
df_ = df [ df . dim_in == embed_dim * 4 ]
df_ = df_ [ df_ . dim_out == embed_dim ]
for k_ in k . split ( ' + ' ) :
y_ + = df_ [ k_ ] . values [ 0 ]
ys . append ( y_ * 0.5 )
all_xs . append ( xs )
all_ys . append ( ys )
color = cmap ( j * 0.25 )
real_ys = [ - ( ( all_ys [ 1 ] [ i ] - all_ys [ 0 ] [ i ] ) / all_ys [ 0 ] [ i ] ) * 100 for i in range ( len ( all_ys [ 0 ] ) ) ]
markers = [ ' ^ ' , ' v ' , ' P ' , ' o ' ]
ax . plot ( all_xs [ 0 ] , real_ys , color = color , label = f ' batch * sequence length = { batch_size } ' , marker = markers [ j ] , markersize = 5 if marker == ' s ' else 5 )
ax . legend ( )
ax . set_xlabel ( ' dim ' , fontsize = 13 )
ax . set_xscale ( ' log ' )
ax . grid ( )
ax . set_ylabel ( r ' % s peedup ' , fontsize = 13 )
ax . tick_params ( axis = ' x ' , labelsize = 11 )
ax . tick_params ( axis = ' y ' , labelsize = 11 )
2023-04-01 18:46:04 +00:00
ax . set_xticks ( dims_to_xtick )
ax . set_xticklabels ( dims_to_xtick )
2023-03-29 06:47:08 +00:00
ax . set_xticks ( [ ] , minor = True )
ax . set_title ( ' Linear layer summary, varying dimensions ' , fontsize = 10 , loc = ' left ' , y = 1.05 , pad = - 20 )
2023-04-01 18:46:04 +00:00
plt . savefig ( ' speed_benchmark/plot_with_info.pdf ' , bbox_inches = ' tight ' )
2023-03-29 06:47:08 +00:00