# File-viewer header: Python, 139 lines, 5.1 KiB.
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
import numpy as np
|
|
import os
|
|
|
|
import matplotlib.gridspec as gridspec
|
|
|
|
# Colormap used later in this script to pick a distinct color for each
# batch-size curve.
cmap = plt.get_cmap('cool')
|
|
|
|
# Timing columns that are summed to obtain the total time of each method.
STANDARD_TOTAL_KEYS = ('standard_gx', 'standard_gw', 'standard_fwd')
SWITCHBACK_TOTAL_KEYS = (
    'x_quantize_rowwise',
    'g_quantize_rowwise',
    'w_quantize_global',
    'w_quantize_global_transpose',
    'standard_gw',
    'global_fwd',
    'global_bwd',
)

# Series drawn in the first panel:
# (timing columns to sum, marker, line style, color, legend label).
PLOT1_SERIES = [
    (STANDARD_TOTAL_KEYS, 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
    (SWITCHBACK_TOTAL_KEYS, 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),

    (('standard_fwd',), '^', '--', 'C2', 'Matmul XW (standard)'),
    (('standard_gw',), '^', '-.', 'C2', 'Matmul GW (standard)'),
    (('standard_gx',), '^', ':', 'gray', 'Matmul GX (both)'),

    (('global_fwd',), '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
    (('global_bwd',), '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),

    (('x_quantize_rowwise',), 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
    (('g_quantize_rowwise',), 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
    (('w_quantize_global',), '.', '--', 'C4', 'Quantize global W (switchback)'),
    (('w_quantize_global_transpose',), '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
]


def mean_layer_time(df, embed_dim, keys):
    """Average the summed timings over the two layer shapes for one width.

    The two shapes are ``embed_dim -> 4 * embed_dim`` and
    ``4 * embed_dim -> embed_dim``.

    Args:
        df: DataFrame with ``dim_in`` and ``dim_out`` columns plus one column
            for every name in ``keys``.
        embed_dim: Base dimension of the layer pair.
        keys: Names of the timing columns to add up.

    Returns:
        Half of the sum of the ``keys`` values taken from the first matching
        row of each shape.

    Raises:
        ValueError: If ``df`` has no row for one of the two shapes.
    """
    total = 0
    for dim_in, dim_out in ((embed_dim, embed_dim * 4), (embed_dim * 4, embed_dim)):
        rows = df[(df['dim_in'] == dim_in) & (df['dim_out'] == dim_out)]
        if rows.empty:
            raise ValueError(
                f'no benchmark row with dim_in={dim_in}, dim_out={dim_out}'
            )
        # Accumulate one column at a time so the floating-point summation
        # order stays fixed.
        for key in keys:
            total += rows[key].iloc[0]
    return total * 0.5


def percent_speedup(baseline, candidate):
    """Return the relative time saved by ``candidate`` over ``baseline`` in %.

    Positive values mean ``candidate`` took less time than ``baseline``.
    """
    return [-((new - ref) / ref) * 100 for ref, new in zip(baseline, candidate)]


def format_batch_size(batch_size):
    """Format a batch size for the title, e.g. ``32768`` -> ``'32k'``.

    Values that are not a multiple of 1024 are returned as plain digits.
    """
    if batch_size % 1024 == 0:
        return f'{batch_size // 1024}k'
    return str(batch_size)


def apply_dim_ticks(ax, dims_to_xtick):
    """Apply the tick styling shared by both panels to ``ax``."""
    ax.tick_params(axis='x', labelsize=11)
    ax.tick_params(axis='y', labelsize=11)

    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    # Drop the minor ticks so only the explicitly listed dims are marked.
    ax.set_xticks([], minor=True)


def main():
    """Draw the two-panel benchmark figure and save it as a PDF.

    The left panel shows the time of each operation against the layer
    dimension for one batch size. The right panel shows the % speedup of the
    SwitchBack total over the standard total for several batch sizes.
    """
    fig = plt.figure(tight_layout=True, figsize=(12, 3.5))
    gs = gridspec.GridSpec(1, 2)

    dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
    batch_size_for_plot1 = 32768
    batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17]
    dims_to_xtick = [1024, 2048, 4096]
    logscale_plot1 = True

    ax = fig.add_subplot(gs[0, 0])

    # TODO: change this to what you want.
    rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
    df = rdf[rdf.batch_size == batch_size_for_plot1]

    # first plot the time occupied by different operations
    for keys, marker, ls, color, name in PLOT1_SERIES:
        ys = [mean_layer_time(df, dim, keys) for dim in dims_to_consider]
        ax.plot(
            dims_to_consider,
            ys,
            color=color,
            label=name,
            marker=marker,
            markersize=5,
            linestyle=ls,
            # Totals are drawn thicker than their individual parts.
            linewidth=2 if len(keys) > 1 else 1.,
        )

    ax.set_xlabel('dim', fontsize=13)
    ax.set_ylabel('time (ms)', fontsize=13)

    ax.grid()

    ax.set_xscale('log')
    if logscale_plot1:
        ax.set_yscale('log')

    apply_dim_ticks(ax, dims_to_xtick)

    leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
    # Bold the legend entries of the "sum of parts" series (the thick lines).
    for text, (keys, *_) in zip(leg.get_texts(), PLOT1_SERIES):
        if len(keys) > 1:
            text.set_fontweight('bold')
    plt.subplots_adjust(left=0.1)
    ax.set_title(
        f' Linear layer, batch * sequence length = {format_batch_size(batch_size_for_plot1)}',
        fontsize=10, loc='left', y=1.05, pad=-20,
    )

    ax = fig.add_subplot(gs[0, 1])

    # now plot the % speedup for different batch sizes
    markers = ['^', 'v', 'P', 'o']
    for j, batch_size in enumerate(batch_sizes_for_plot2):
        df = rdf[rdf.batch_size == batch_size]
        standard = [mean_layer_time(df, dim, STANDARD_TOTAL_KEYS) for dim in dims_to_consider]
        switchback = [mean_layer_time(df, dim, SWITCHBACK_TOTAL_KEYS) for dim in dims_to_consider]

        ax.plot(
            dims_to_consider,
            percent_speedup(standard, switchback),
            # Spread the colors evenly over the colormap, however many
            # batch sizes are configured.
            color=cmap(j / len(batch_sizes_for_plot2)),
            label=f'batch * sequence length = {batch_size}',
            # Reuse markers cyclically if there are more curves than markers.
            marker=markers[j % len(markers)],
            markersize=5,
        )

    ax.legend()
    ax.set_xlabel('dim', fontsize=13)
    ax.set_xscale('log')
    ax.grid()
    ax.set_ylabel(r'% speedup', fontsize=13)

    apply_dim_ticks(ax, dims_to_xtick)

    ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)

    plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')


if __name__ == '__main__':
    main()
|
|
|