#!/usr/bin/python

import os, sys
import subprocess
import re
import math

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import numpy as np

import itertools

import pdb

def add_event(dictionary, scale):
    """Count one more occurrence of *scale* in *dictionary*.

    *dictionary* maps a key to the number of times it has been seen; a key
    that is not present yet starts at a count of 1.  The mapping is updated
    in place and nothing is returned.
    """
    dictionary[scale] = dictionary.get(scale, 0) + 1

def convert(num):
    """Return a short human-readable label for a size of *num* bytes.

    Sizes of at least 1 MiB are shown in whole 'MB', sizes of at least 1 KiB
    in whole 'kB', everything else in 'B'.  The value is truncated to a whole
    number of units (floor division), so the result is the same under both
    Python 2 and Python 3 and never contains a fractional part.

    The value 17 is special-cased to '1 B': main() folds every request of up
    to 32 bytes into its first histogram bucket, whose computed lower bound is
    2**4 + 1 == 17, and that bucket is labelled '1 B - 32 B'.
    """
    if num >= 1048576:
        # Floor division: plain '/' yields a float under true division and
        # would render labels such as '1.0 MB'.
        return '{} MB'.format(num // 1048576)
    if num >= 1024:
        return '{} kB'.format(num // 1024)
    # We start from 1 - 32 bytes anyway
    if num == 17:
        return '1 B'
    return '{} B'.format(num)

def megabytes(x, pos):
    """Format the byte count *x* as a megabyte label with no decimals.

    The (x, pos) signature is the one main() hands to FuncFormatter for the
    Y axis of the memory usage plot; *pos* is accepted but not used.
    """
    bytes_per_megabyte = 1048576
    size_in_megabytes = x / bytes_per_megabyte
    return '%1.f MB' % size_in_megabytes

def _announce(message):
    """Write a progress message without a newline and flush it right away."""
    sys.stdout.write(message + ' ')
    sys.stdout.flush()


def _count_events(path):
    """Return how many lines of *path* start with 'dmmlib - ms all'."""
    with open(path) as infile:
        return sum(1 for line in infile if line.startswith('dmmlib - ms all'))


def _size_scale(request_size):
    """Return the histogram bucket exponent for a request of *request_size* bytes.

    The exponent is ceil(log2(request_size)), computed with integer
    arithmetic so that exact powers of two are never pushed into the next
    bucket by floating point error and a zero-byte request does not raise.
    Everything up to 32 bytes shares the first bucket (exponent 5).
    """
    return max(5, (request_size - 1).bit_length())


def main(arg1):
    """Plot the request size histogram and memory usage timeline of a trace.

    arg1 is the path of a dmmlib trace file.  Two files are written next to
    it: '<arg1>_histogram.pdf' (number of allocation requests per power-of-two
    size bucket) and '<arg1>_mem_usage.pdf' (requested vs. allocated memory
    over the course of the log).

    Returns 0 on success and 1 when the log contains none of the recognised
    event lines.
    """
    _announce('Checking how many memory events are in the log...')
    # Counted in Python rather than through a shell pipeline, so file names
    # containing spaces or shell metacharacters are handled safely.
    num_events = _count_events(arg1)
    print('{}.'.format(num_events))

    # Keep the timeline at roughly points_to_draw points by recording only
    # every sample_pack-th event (1 means every event is recorded).
    points_to_draw = 50000
    sample_pack = max(1, int(math.ceil(float(num_events) / points_to_draw)))
    requested_seen = 0
    allocated_seen = 0

    # Example: dmmlib - m 0x7f5817829178 26
    malloc_pattern = re.compile(r'^dmmlib - m (0x\w+) (\d+)')
    # Example: dmmlib - ms req 609
    requested_pattern = re.compile(r'^dmmlib - ms req (\d+)')
    # Example: dmmlib - ms all 1280
    allocated_pattern = re.compile(r'^dmmlib - ms all (\d+)')

    scale_dictionary = {}
    requested = []
    allocated = []

    print('Parsing the log file...')
    sys.stdout.flush()
    with open(arg1) as infile:
        for line in infile:
            # The three patterns are mutually exclusive, so at most one of
            # the branches below applies to any given line.
            malloc_match = malloc_pattern.match(line)
            if malloc_match:
                request_size = int(malloc_match.group(2))
                add_event(scale_dictionary, _size_scale(request_size))
                continue

            requested_match = requested_pattern.match(line)
            if requested_match:
                if requested_seen % sample_pack == 0:
                    requested.append(int(requested_match.group(1)))
                requested_seen += 1
                continue

            allocated_match = allocated_pattern.match(line)
            if allocated_match:
                if allocated_seen % sample_pack == 0:
                    allocated.append(int(allocated_match.group(1)))
                allocated_seen += 1
    print('Done')

    if not (scale_dictionary or requested or allocated):
        sys.stderr.write(
            'ERROR: No memory events were found in {}\n'.format(arg1))
        return 1

    # Histogram of requested sizes - START
    if scale_dictionary:
        _announce('Creating a histogram plot for the requested sizes...')

        hist_fig = plt.figure()
        hist_ax = hist_fig.add_subplot(111)

        scales = sorted(scale_dictionary)
        counts = [scale_dictionary[scale] for scale in scales]
        # Shift the bars so that the first bucket (exponent 5) sits at x = 1
        positions = [scale - 4 for scale in scales]
        hist_ax.bar(positions, counts, width=0.2, align='center')

        # Bucket with exponent k covers the sizes 2**(k - 1) + 1 .. 2**k
        xlabels = ['{} - {}'.format(convert(2 ** (scale - 1) + 1),
                                    convert(2 ** scale))
                   for scale in scales]

        hist_ax.set_xticks(positions)
        hist_ax.set_xticklabels(xlabels)
        # Show only the xticks of the bottom X axis
        hist_ax.get_xaxis().tick_bottom()

        hist_ax.set_title('Memory request sizes')
        hist_ax.set_ylabel('Counts')

        # Rotates and right-aligns the tick labels so they do not overlap
        hist_fig.autofmt_xdate()

        hist_fig.savefig('{}_histogram.pdf'.format(arg1))
        plt.close(hist_fig)

        print('Done.')
    else:
        print('No allocation requests were found, '
              'skipping the histogram plot.')
    # Histogram of requested sizes - END

    # Memory usage - START
    _announce('Creating a plot for the memory utilization timeline...')

    req_fig = plt.figure()
    req_ax = req_fig.add_subplot(111)

    req_ax.plot(requested, label='requested memory', rasterized=True)
    req_ax.plot(allocated, label='allocated memory', rasterized=True)

    # The X axis is just the order of the events in the log, so no ticks
    req_ax.set_xticks([])
    req_ax.set_xlabel('Application timeline')

    req_ax.yaxis.set_major_formatter(FuncFormatter(megabytes))
    req_ax.get_yaxis().tick_left()

    # Place the legend in a single row right above the plot area
    req_ax.legend(bbox_to_anchor=(0, 1.02, 1., .102), loc=3, ncol=2,
                  mode="expand", borderaxespad=0.)

    req_fig.savefig('{}_mem_usage.pdf'.format(arg1), dpi=150)
    plt.close(req_fig)
    print('Done.')
    # Memory usage - END

    return 0

if __name__ == '__main__':
    # Script entry point: expects the trace file path as the first argument.
    if len(sys.argv) < 2:
        sys.exit('Usage: %s trace-file' % sys.argv[0])

    trace_file = sys.argv[1]
    if not os.path.exists(trace_file):
        sys.exit('ERROR: Trace file %s was not found!' % trace_file)

    # The return value of main() becomes the exit status of the process.
    sys.exit(main(trace_file))