#!/usr/bin/env python

import nsysstats

class SyncMemset(nsysstats.Report):
    """Report listing synchronous CUDA memset operations.

    The query joins memset activity records with the runtime API call that
    issued them (matched on ``correlationId``) and keeps only calls whose
    name starts with ``cudaMemset`` and does not contain ``async``.  Rows
    are sorted by duration, longest first, and capped at ``rows`` entries.
    """

    # Default cap on the number of rows returned when no 'rows=' argument
    # is supplied.
    ROW_LIMIT = 50

    usage = f"""{{SCRIPT}}[:rows=<limit>] -- SyncMemset

    Options:
        rows=<limit> - Maximum number of rows returned by the query.
            Default is {ROW_LIMIT}.

    Output: All time values default to nanoseconds
        Duration : Duration of memset on GPU
        Start : Start time of memset on GPU
        Memory Kind : Type of memory being set
        Bytes : Number of bytes transferred
        PID : Process identifier
        Device ID : GPU device identifier
        Context ID : Context identifier
        Stream ID : Stream identifier
        API Name : Name of runtime API function

    This rule identifies synchronous memset operations with pinned host
    memory or Unified Memory region.
"""

    # SQL template; {MEM_KIND_STRS_CTE} and {ROW_LIMIT} are filled in by
    # setup().
    #
    # NOTE(review): the memKind filter (1 or 4) presumably selects pinned
    # and managed (Unified) memory, matching the usage text above -- confirm
    # against the MemKindStrs table provided by nsysstats.
    #
    # The runtime table is joined before 'sync' so that every ON clause only
    # references tables already introduced to its left.
    query_sync_memset = """
    WITH
        {MEM_KIND_STRS_CTE}
        sync AS (
            SELECT
                id,
                value
            FROM
                StringIds
            WHERE
                value LIKE 'cudaMemset%'
                AND value NOT LIKE '%async%'
        ),
        memset AS (
            SELECT
                *
            FROM
                CUPTI_ACTIVITY_KIND_MEMSET
            WHERE
                memKind = 1
                OR memKind = 4
        )
    SELECT
        memset.end - memset.start AS "Duration:dur_ns",
        memset.start AS "Start:ts_ns",
        mk.name AS "Memory Kind",
        bytes AS "Bytes:mem_B",
        (globalPid >> 24) & 0x00FFFFFF AS "PID",
        deviceId AS "Device ID",
        contextId AS "Context ID",
        streamId AS "Stream ID",
        sync.value AS "API Name",
        globalPid AS "_Global PID"
    FROM
        memset
    JOIN
        CUPTI_ACTIVITY_KIND_RUNTIME AS runtime
        ON runtime.correlationId = memset.correlationId
    JOIN
        sync
        ON sync.id = runtime.nameId
    LEFT JOIN
        MemKindStrs AS mk
        ON memKind = mk.id
    ORDER BY
        1 DESC
    LIMIT {ROW_LIMIT}
"""

    # Tables the query reads, each mapped to the message associated with
    # that table being absent from the database.
    table_checks = {
        'CUPTI_ACTIVITY_KIND_RUNTIME':
            "{DBFILE} could not be analyzed because it does not contain CUDA trace data.",
        'CUPTI_ACTIVITY_KIND_MEMSET':
            "{DBFILE} could not be analyzed because it does not contain CUDA trace data."
    }

    def setup(self):
        """Parse the report arguments and build the final SQL query.

        The only accepted argument form is ``rows=<non-negative integer>``;
        when given more than once, the last occurrence wins.

        Returns:
            Whatever ``super().setup()`` returned if that value is not
            ``None``; otherwise ``None`` after ``self.query`` has been set.

        Raises:
            SystemExit: with code ``self.EXIT_INVALID_ARG`` when any
                argument is not of the accepted form.
        """
        err = super().setup()
        if err is not None:
            return err

        row_limit = self.ROW_LIMIT
        for arg in self.args:
            key, _, value = arg.partition('=')
            # A missing '=' leaves value empty and a second '=' ends up
            # inside value; both fail isdecimal(), so only 'rows=<digits>'
            # passes.  isdecimal() (unlike isdigit()) accepts only
            # characters that int() can convert.
            if key != 'rows' or not value.isdecimal():
                # Raise SystemExit directly: the exit() builtin is injected
                # by the site module and is not guaranteed to exist.
                raise SystemExit(self.EXIT_INVALID_ARG)
            # Converting to int guarantees plain ASCII digits in the SQL.
            row_limit = int(value)

        self.query = self.query_sync_memset.format(
            MEM_KIND_STRS_CTE=self.MEM_KIND_STRS_CTE,
            ROW_LIMIT=row_limit)

# Script entry point: when executed directly (not imported), hand control
# to the Main() driver of the report class.
if __name__ == '__main__':
    SyncMemset.Main()
