
    PiGi                        d dl Z d dlZd dlZd dlZd dlmZ d dlmZ d dlm	Z	 d dl
mZ d dlmZ d dlmZ d dlmZ ej                            ej                            e                    Zej                            ed	          gZd
 Z e j                    d             Z G d de          Zd ZddddddZddddddZdZ d Z!d Z" G d de          Z# G d de          Z$dS )    N)Path)knobs)	GPUTarget)	GPUDriver)_allocation)compile_module_from_src)TensorDescriptorincludec                 :    dd l }|                                dk    rd S dd lddlm}m}m}mmm}  G fddj	                  }
                    | ||           ||           ||                    }	                     d          j        }n# t          $ r Y d S w xY w|g|_        ||_        d                    dz             }	 fd	}
 | ||
          |	          r't#          j                            |	                    S d S )
Nr   Linux)c_charc_intc_size_tc_void_pc_char_pPOINTERc                        e Zd Zdfd fgZdS )8_find_already_mmapped_dylib_on_linux.<locals>.DlPhdrInfo	dlpi_addr	dlpi_nameN)__name__
__module____qualname___fields_)r   r   s   r/var/www/development/aibuddy-work/election-extract/venv/lib/python3.11/site-packages/triton/backends/amd/driver.py
DlPhdrInfor      s&        (#(#
    r   z	libc.so.6i      c           
          | j         j        }t          t          j        |                    }|j        v r4                    ||t          t          |                               dS dS )Nr   r   )	contentsr   r   osfsdecodenamememmoveminlen)infosizedatar   pctypeslib_namemax_path_lengths        r   callbackz6_find_already_mmapped_dylib_on_linux.<locals>.callback3   sa    M+	Y''((qvNN4CY,P,PQQQ1qr   )platformsystemr+   r   r   r   r   r   r   	Structure	CFUNCTYPECDLLdl_iterate_phdr	Exceptionargtypesrestypecreate_string_bufferr!   r"   	string_at)r,   r/   r   r   r   r   r   
callback_tr4   pathr.   r   r   r+   r-   s   `          @@@@r   $_find_already_mmapped_dylib_on_linuxr<      s   OOOG##t
 MMMKKKKKKKKKKKKKKKK
 
 
 
 
 
 
 
V% 
 
 
 !!%)<)<ggh>O>OQXQXY_Q`Q`aaJ ++k22B   tt !+H5O#OO&&':;;D       zz(++T22 3{6++D112224s   B 
B+*B+c                  ~	   dt           j        j        x} rK|                               r!t          j                            |           r| S t          d|  d           t                    }|r6t          j                            |          r|S t          d| d           g }t          j        	                    t          j        
                    t                    d          }t          j                            |          r|S |                    |           dd l}|                                }|                                }|j        r|g|z   }|D ]\}t          j        	                    |dd          }t          j                            |          r|c S |                    |           ]t	          j        d	          }|rp|                    d
          D ]Z}	t          j        	                    |	          }
t          j                            |
          r|
c S |                    |
           [t	          j        d          }|rWt          j        	                    |d          }t          j                            |          r|S |                    |           	 t'          j        ddg                                                                          }|rWt          j        	                    |d          }t          j                            |          r|S |                    |           n# t&          j        t0          f$ r Y nw xY wt	          j        d          }|rWt          j        	                    |d          }t          j                            |          r|S |                    |           t'          j        ddg                              d          }fd|                                D             }|D ]:}t          j                            |          r|c S |                    |           ;t          j        	                    d          }t          j                            |          r|S |                    |           t          d d|           )Nzlibamdhip64.sozTRITON_LIBHIP_PATH 'z' does not point to a valid zmemory mapped 'z'' in process does not point to a valid libr   torchLD_LIBRARY_PATH:HIP_PATH	hipconfigz--path	ROCM_PATHz/sbin/ldconfigz-pignore)errorsc                     g | ]C}|                                                               )|                                d          DS ))stripendswithsplit).0liner,   s     r   
<listcomp>z2_get_path_to_hip_runtime_dylib.<locals>.<listcomp>   sB    ^^^djjll>S>ST\>]>]^DJJLL^^^r   z/opt/rocm/lib/zcannot locate z after attempted paths )r   amdlibhip_pathrJ   r!   r;   existsRuntimeErrorr<   joindirname__file__appendsitegetsitepackagesgetusersitepackagesENABLE_USER_SITEgetenvrK   
subprocesscheck_outputdecoderI   CalledProcessErrorFileNotFoundError
splitlines)env_libhip_pathmmapped_pathpaths	local_librW   site_packages	user_siter;   env_ld_library_pathdfenv_hip_pathhip_lib_pathhip_rootenv_rocm_pathrocm_lib_pathlibslocsloccommon_install_pathr,   s                       @r   _get_path_to_hip_runtime_dylibrt   A   s   H  )// k##H-- 	#"'..2Q2Q 	#""i/ii_giijjj 8AAL n7>>,'' 	 l\llbjllmmmE RW__X66xHHI	w~~i   	LLKKK ((**M((**I 4"m3  w||D'5(;;7>>$ 	KKKT )$566 $**3// 	 	AQ))Aw~~a   LLOOOO 9Z((L #w||L%BB7>>,'' 	 \"""	*K+BCCJJLLRRTT 	'7<<%BBLw~~l++ $##LL&&&)+<=   
 Ik**M $]E8DD7>>-(( 	!  ]### "$4d#;<<CC8CTTD _^^^):):^^^D  7>># 	JJJS ',,'7BB	w~~)** #""	LL$%%%
PPPPP
Q
QQs   0A=M .M MMc                   $     e Zd Z fdZd Z xZS )HIPUtilsc                     t          | d          s-t          t          |                               |           | _        | j        S )Ninstance)hasattrsuperrv   __new__rx   )cls	__class__s    r   r{   zHIPUtils.__new__   s<    sJ'' 	= 3//77<<CL|r   c                 6   t                      }t          t          j                            t
          d                                                    }|                    d|d          }t          |dt                    }|j
        | _
        |j        | _        d S )Nzdriver.cz/*py_libhip_search_path*/r   	hip_utilssrcr#   include_dirs)rt   r   r!   r;   rS   rT   	read_textreplacer   r   load_binaryget_device_properties)selfrP   r   mods       r   __init__zHIPUtils.__init__   s}    46627<<4455??AA kk5{AFF%#Kl[[[?%(%>"""r   )r   r   r   r{   r   __classcell__r}   s   @r   rv   rv      sG            
	? 	? 	? 	? 	? 	? 	?r   rv   c                 N    | d         dk    rdS dddddddd	d
ddddddd|          S )Nr   *hipDeviceptr_tint8_tint16_tint32_tint64_tuint8_tuint16_tuint32_tuint64_tdouble)i1i8i16i32i64u1u8u16u32u64fp16bf16fp32f32fp64 )tys    r   	ty_to_cppr      s[    	!u||   	!
 
r   r   r   r   )r   r   r   r   r   	pack_fp16	pack_bf16	pack_fp32	pack_fp64piiiKKOOOOOc                    d }fdfdfdd t           ||                                                    D             }d                    fd|                                D                       }t          |z   }d                    t	          |                                                    }t          t          t          |                    d                              }d	 t          |          D             }t          |          d
k    r4dd                    d |
                                D                       z   nd}g }|
                                D ]b\  }}	|	dk    r|	t          v r&|                    t          |	          d|            ;|                    t          |	           d|            cd                    |          }
g }|
                                D ]l\  }}	|	d
         dk    r|                    d| d           +|	t          v r|                    d| d           N|	dk    r|                    d|            md |
                                D             }t                      }t          t          t          |                              }d |
                                D             }|                    d           |                    d           d| dt          |
          d
k    rd|
z   nd dd                    |           d| d| dd                    fd|
                                D                        d | d!| d"d                    |           d#d$                    d% |
                                D                        d&t          |          d
k    rdd                    |          z   nd d'}|S )(Nc                 b   g }| D ](}t          |t                    r|                    d          r|                    d          dz   }t	          j        d|                                          }|                    d|z              t          d|z            D ]}|                    d           |                    d           t          |          D ]}|                    d	           t          |          D ]}|                    d           |                    |           *|S )
N
tensordesc,r   ztensordesc<([^[>]*)r      r   r   r   )	
isinstancestr
startswithcountrematchgrouprV   range)	signatureoutputsigndimdtype_s         r   _expand_signaturez(make_launcher.<locals>._expand_signature   s<     	# 	#C#s## #|(D(D #yy~~)!6<<BBDDcEk***q4x ) )AMM%((((d### t ) )AMM%((((t ) )AMM%(((() c""""r   c                 x    t          | t                    r#d                    t          |                     S | S )Nr   )r   tuplerS   map)r   _serialize_signatures    r   r   z+make_launcher.<locals>._serialize_signature   s7    c5!! 	<88C 4c::;;;
r   c                     t          | t                    r)d                    t          |                     }d| dS | d         dk    rdS | dk    rdS t	          |           S )Nr   []r   r   z	PyObject*	constexprr   r   rS   r   r   )r   val_extracted_types     r   r   z&make_launcher.<locals>._extracted_type   sl    b%   	((33344Cs:::a5C<<;;}}r   c                     t          | t                    r)d                    t          |                     }d| dS | d         dk    rdS | dk    rdS dd	d
dddddddd
t	          |                    S )N ()r   r   Or   ri   lbhiLBHIK)
r   longr   r   r   r   r   r   r   r   r   )r   r   	format_ofs     r   r   z make_launcher.<locals>.format_of	  s    b%   	''#i,,--Cs:::a5C<<33
 
 B-- 	r   c                     i | ]\  }}||	S r   r   )rL   idxss      r   
<dictcomp>z!make_launcher.<locals>.<dictcomp>  s    WWWFCaWWWr   r   c                 &    g | ]} |          S r   r   )rL   r   r   s     r   rN   z!make_launcher.<locals>.<listcomp>   s!    FFFR99R==FFFr   r   c                     i | ]\  }}||	S r   r   )rL   r   r   s      r   r   z!make_launcher.<locals>.<dictcomp>$  s    777$!QA777r   r   , c              3   &   K   | ]\  }}d | V  dS )z&_argNr   rL   r   r   s      r   	<genexpr>z make_launcher.<locals>.<genexpr>%  s,       L LB L L L L L Lr   r   z argr   ptr_infoz.dev_ptr_arg_storagec           
      n    g | ]2\  }}|t           v t           |          d | dt          |          d| d3S ) _argz_storage = z(_argz);)FLOAT_STORAGE_TYPEFLOAT_PACK_FUNCTIONr   s      r   rN   z!make_launcher.<locals>.<listcomp>:  s_       Ar### b!YYYY6I"6MYYTUYYY###r   c                 *    g | ]\  }}|d k    d| S )r   z&argr   r   s      r   rN   z!make_launcher.<locals>.<listcomp>D  s,    MMMUQ2;L;LjQjj;L;L;Lr   z&global_scratchz&profile_scratcha\  
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#include <Python.h>
#include <dlfcn.h>
#include <stdbool.h>
#include <dlfcn.h>

// The list of paths to search for the HIP runtime library. The caller Python
// code should substitute the search path placeholder.
static const char *hipLibSearchPaths[] = {"a  "};

// The list of HIP dynamic library symbols and their signature we are interested
// in this file.
#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN)                     \
  FOR_EACH_STR_FN(hipGetLastError)                                            \
  FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError)                     \
  FOR_EACH_ERR_FN(hipModuleLaunchKernel, hipFunction_t f,                     \
                  unsigned int gridDimX, unsigned int gridDimY,               \
                  unsigned int gridDimZ, unsigned int blockDimX,              \
                  unsigned int blockDimY, unsigned int blockDimZ,             \
                  unsigned int sharedMemBytes, hipStream_t stream,            \
                  void **kernelParams, void **extra)                          \
  FOR_EACH_ERR_FN(hipModuleLaunchCooperativeKernel, hipFunction_t f,          \
                  unsigned int gridDimX, unsigned int gridDimY,               \
                  unsigned int gridDimZ, unsigned int blockDimX,              \
                  unsigned int blockDimY, unsigned int blockDimZ,             \
                  unsigned int sharedMemBytes, hipStream_t stream,            \
                  void **kernelParams, void **extra)                          \
  FOR_EACH_ERR_FN(hipPointerGetAttribute, void *data,                         \
                  hipPointer_attribute attribute, hipDeviceptr_t ptr)

// The HIP symbol table for holding resolved dynamic library symbols.
struct HIPSymbolTable {
#define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...)                             \
  hipError_t (*hipSymbolName)(__VA_ARGS__);
#define DEFINE_EACH_STR_FIELD(hipSymbolName, ...)                             \
  const char *(*hipSymbolName)(__VA_ARGS__);

  HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD)
};

static struct HIPSymbolTable hipSymbolTable;

bool initSymbolTable() {
  // Use the HIP runtime library loaded into the existing process if it exits.
  void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);

  // Otherwise, go through the list of search paths to dlopen the first HIP
  // driver library.
  if (!lib) {
    int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
    for (int i = 0; i < n; ++i) {
      void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
      if (handle) {
        lib = handle;
      }
    }
  }
  if (!lib) {
    PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
    return false;
  }

  typedef hipError_t (*hipGetProcAddress_fn)(
      const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
      hipDriverProcAddressQueryResult *symbolStatus);
  hipGetProcAddress_fn hipGetProcAddress;
  dlerror(); // Clear existing errors
  const char *error = NULL;
  *(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
  error = dlerror();
  if (error) {
    PyErr_SetString(PyExc_RuntimeError,
                    "cannot query 'hipGetProcAddress' from libamdhip64.so");
    dlclose(lib);
    return false;
  }

  // Resolve all symbols we are interested in.
  int hipVersion = HIP_VERSION;
  uint64_t hipFlags = 0;
  hipDriverProcAddressQueryResult symbolStatus;
  hipError_t status = hipSuccess;
#define QUERY_EACH_FN(hipSymbolName, ...)                                        status = hipGetProcAddress(#hipSymbolName,                                                                (void **)&hipSymbolTable.hipSymbolName,                                        hipVersion, hipFlags, &symbolStatus);               if (status != hipSuccess) {                                                     PyErr_SetString(PyExc_RuntimeError,                                                            "cannot get address for '" #hipSymbolName                                      "' from libamdhip64.so");                                      dlclose(lib);                                                                  return false;                                                                }

  HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)

  return true;
}

static inline void gpuAssert(hipError_t code, const char *file, int line)
{
   if (code != HIP_SUCCESS)
   {
      const char* prefix = "Triton Error [HIP]: ";
       const char* str = hipSymbolTable.hipGetErrorString(code);
      char err[1024] = {0};
      snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
      PyErr_SetString(PyExc_RuntimeError, err);
   }
}

#define HIP_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); }

static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function, hipDeviceptr_t profile_scratchz>) {
  hipDeviceptr_t global_scratch = 0;
  void *params[] = { z };
  if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {
    HIP_CHECK(hipSymbolTable.hipModuleLaunchCooperativeKernel(function, gridX, gridY, gridZ, z*num_warps, 1, 1, shared_memory, stream, params, 0));
    return;
  }
  if (gridX*gridY*gridZ > 0) {
    HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, ae
  *num_warps, 1, 1, shared_memory, stream, params, 0));
  }
}

typedef struct _DevicePtrInfo {
    hipDeviceptr_t dev_ptr;
    bool valid;
} DevicePtrInfo;

static PyObject* data_ptr_str = NULL;

static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {
  DevicePtrInfo ptr_info;
  hipError_t status = hipSuccess;
  ptr_info.dev_ptr = 0;
  ptr_info.valid = true;
  if (PyLong_Check(obj)) {
    ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
    return ptr_info;
  }
  if (obj == Py_None) {
    // valid nullptr
    return ptr_info;
  }
  PyObject *ret = PyObject_CallMethodNoArgs(obj, data_ptr_str);
  if (!ret) {
    PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
    ptr_info.valid = false;
    goto cleanup;
  }
  if (!PyLong_Check(ret)) {
    PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
    ptr_info.valid = false;
    goto cleanup;
  }
  ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
  if (!ptr_info.dev_ptr)
    goto cleanup;
  uint64_t dev_ptr;
  status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
  if (status == hipErrorInvalidValue) {
      PyErr_Format(PyExc_ValueError,
                   "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
      ptr_info.valid = false;
      // Clear and ignore HIP error
      (void)hipSymbolTable.hipGetLastError();
  }
  ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
cleanup:
  Py_DECREF(ret);
  return ptr_info;
}

static uint16_t pack_fp16(double f) {
    uint16_t result;
    // from https://github.com/python/pythoncapi-compat/blob/5e317108f872c904eb726cb8d560dcadbdf88a72/pythoncapi_compat.h#L482-L492
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
    _PyFloat_Pack2(f, (unsigned char*)&result, 1);
#else
    PyFloat_Pack2(f, (char*)&result, 1);
#endif
    return result;
}

static uint16_t pack_bf16(double f) {
    float f32 = (float)f;
    uint32_t u32 = *(uint32_t*)&f32;
    return (uint16_t)(u32 >> 16);
}

static uint32_t pack_fp32(double f) {
    float f32 = (float)f;
    return *(uint32_t*)&f32;
}

static uint64_t pack_fp64(double f) {
    return *(uint64_t*)&f;
}

static PyObject* launch(PyObject* self, PyObject* args) {
  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
  int launch_cooperative_grid;
  PyObject *profile_scratch_obj = NULL;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *kernel_metadata = NULL;
  PyObject *launch_metadata = NULL;
   c                 8    g | ]\  }} |           d | dS )r   ; r   )rL   r   r   r   s      r   rN   z!make_launcher.<locals>.<listcomp>  s8    OOOEAr##//!///OOOr   z
  if(!PyArg_ParseTuple(args, "a,  ", &launch_cooperative_grid,
                                           &gridX, &gridY, &gridZ, &_stream, &_function, &profile_scratch_obj,
                                           &kernel_metadata, &launch_metadata,
                                           &launch_enter_hook, &launch_exit_hook z)) {
    return NULL;
  }

  a  

  // extract kernel metadata
  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
  if (!PyArg_ParseTuple(kernel_metadata, "iiiiii", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {
    return NULL;
  }
  // extract launch metadata
  if (launch_enter_hook != Py_None){
    PyObject* ret = PyObject_CallOneArg(launch_enter_hook, launch_metadata);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }

  hipDeviceptr_t profile_scratch = 0;
  if (profile_scratch_obj != Py_None) {
    DevicePtrInfo profile_scratch_info = getPointer(profile_scratch_obj, -1);
    if (!profile_scratch_info.valid) {
      return NULL;
    }
    profile_scratch = profile_scratch_info.dev_ptr;
  }

  // raise exception asap
  r   c                 N    g | ]"\  }}|d          dk    rd| d| d| d| d	nd#S )r   r   zDevicePtrInfo ptr_infoz = getPointer(_argr   z); if (!ptr_infoz.valid) return NULL;r   r   r   s      r   rN   z!make_launcher.<locals>.<listcomp>=  s~      d  d  d  IN  IJ  LNoqrsotx{o{o{kqkkAkkkkTUkkkk  BD  d  d  dr   z;
  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, (hipDeviceptr_t)profile_scratcha  );

  if(launch_exit_hook != Py_None){
    PyObject* ret = PyObject_CallOneArg(launch_exit_hook, launch_metadata);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }

  if(PyErr_Occurred()) {
    return NULL;
  }
  Py_RETURN_NONE;
}

static PyMethodDef ModuleMethods[] = {
  {"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"},
  {NULL, NULL, 0, NULL} // sentinel
};

static struct PyModuleDef ModuleDef = {
  PyModuleDef_HEAD_INIT,
  "__triton_launcher",
  NULL, //documentation
  -1, //size
  ModuleMethods
};

PyMODINIT_FUNC PyInit___triton_launcher(void) {
  if (!initSymbolTable()) {
    return NULL;
  }
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {
    return NULL;
  }
  data_ptr_str = PyUnicode_InternFromString("data_ptr");
  if(data_ptr_str == NULL) {
    return NULL;
  }
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}
)	enumeratevaluesrS   _BASE_ARGS_FORMATr   listfilterboolrK   r&   itemsr   rV   r   rt   r   )	constantsr   	warp_sizer   args_formatformat	args_listarg_decl_listr   r   	arg_declsinternal_args_listfloat_storage_declsrP   paramsr   r   r   r   s                   @@@r   make_launcherr     s     6    
        * XWi0A0A)BRBRBTBT0U0U&V&VWWWI''FFFF93C3C3E3EFFFGGK,F193C3C3E3EFFGGIVD)//#"6"67788I77)I"6"6777IPST]P^P^abPbPbtyy L L)//:K:K L L LLLLLhjI M"" < <2###  $6r$:!C!C!C!CDDDD  IbMM!:!:q!:!:;;;;		-((I"" 2 22a5C<<%%&<&<&<&<====%%%%%&8Q&8&8&89999;%%jQjj111 __&&   122K %I''((FMMioo&7&7MMMF
MM#$$$
MM$%%%b .9b bh UX  Yb  Uc  Uc  fg  Ug  Ug  AE  HQ  AQ  AQ  moib bl yy((mb bp _hqb bx T]yb bj 88OOOOY__=N=NOOOPPkb bl !'mb br S\sb bz 88  {b bl 99  d  d  R[  Ra  Ra  Rc  Rc  d  d  d  e  emb bn |  @R  |S  |S  VW  |W  |W  TX  [_  [d  [d  ew  [x  [x  Tx  Tx  ]_ob b bCF	 Jr   c                       fd}|S )zN
    Replace all tensor descriptors with the base ptr, shape, and strides
    c                  p   | d t          t                             }| t          t                    d          }g }|D ]o}t          |t                    rC|                    |j        g|j        |j        |j        dk    |j        |j                   Z|	                    |           p g ||R  S )Nnan)
r&   r   r   r	   extendbaseshapestridespaddingrV   )args	meta_argsraw_kernel_args
final_argsarglaunchers        r   innerz,wrap_handle_tensor_descriptor.<locals>.innerr  s    0#/0001	s#455667
" 
	' 
	'C#/00 	' !!38"vci"v#+"vs{V[G["v^a^g"vjmju"vwwww!!#&&&&x00Z0000r   r   )r  r  s   ` r   wrap_handle_tensor_descriptorr  m  s#    
1 1 1 1 1" Lr   c                       e Zd Zd Zd ZdS )HIPLauncherc                     t          d          rj        nt                      }fdfd|                                D             }d j                                        D             }t          |||j                  t          dt                    }t          d |
                                D                       }|rt          |j                  n|j        | _        |j        | _        |j        | _        |j        | _        d S )Nr   c                 r    t          | t                    r j        j                            |           fn| S N)r   r   fn	arg_namesindex)xr   s    r   <lambda>z&HIPLauncher.__init__.<locals>.<lambda>  s2    Z3=O=OVSV-33A6699UV r   c                 .    i | ]\  }} |          |S r   r   )rL   r   valuearg_idxs      r   r   z(HIPLauncher.__init__.<locals>.<dictcomp>  s'    MMMZS%WWS\\5MMMr   c                     i | ]\  }}||	S r   r   )rL   r   r!  s      r   r   z(HIPLauncher.__init__.<locals>.<dictcomp>  s    HHHJCS%HHHr   __triton_launcherr   c              3   j   K   | ].}t          |t                    o|                    d           V  /dS )r   N)r   r   r   )rL   r   s     r   r   z'HIPLauncher.__init__.<locals>.<genexpr>  s>      !v!v\_*S#"6"6"W3>>,;W;W!v!v!v!v!v!vr   )ry   r   dictr   r   r  r   r   r   anyr   r  launchlaunch_cooperative_gridprofile_scratch_sizeprofile_scratch_align)r   r   metadatar   r   r   has_tensor_desc_argr"  s    `     @r   r   zHIPLauncher.__init__  s   %,S+%>%>JCMMDFF	VVVVMMMM9??;L;LMMM	HH#-2E2E2G2GHHH	Iy(2DEE%#4GVbccc!!v!vclcscscucu!v!v!vvvCVf3CJ???\_\f'/'G$$,$A!%-%C"""r   c           	          fd} || j         | j        t          j                  } | j        | j        ||g|R   d S )Nc                 p    | dk    r.z  z  }|| z  }|                                 } |||	          S d S Nr   )get)
r(   align	allocator	grid_size
alloc_sizealloc_fngridXgridYgridZstreams
         r   allocate_scratchz.HIPLauncher.__call__.<locals>.allocate_scratch  sJ    axx!EME1	&-
$==??x
E6:::4r   )r*  r+  r   _profile_allocatorr(  r)  )	r   r7  r8  r9  r:  functionr  r;  profile_scratchs	    ````    r   __call__zHIPLauncher.__call__  s    	 	 	 	 	 	 	 	 +*4+DdF`+6+IK K 	D0%vxYhpkoppppppr   N)r   r   r   r   r?  r   r   r   r  r    s:        D D Dq q q q qr   r  c                   h     e Zd Z fdZd Zed             ZdedefdZd Z	d Z
d	 Zd
 Zd Z xZS )	HIPDriverc                     t                                                       t                      | _        t          | _        d S r  )rz   r   rv   utilsr  launcher_cls)r   r}   s    r   r   zHIPDriver.__init__  s2    ZZ
'r   c                     dd l }|j        S r0  )r?   cudar   r?   s     r   get_device_interfacezHIPDriver.get_device_interface  s    zr   c                  |    	 dd l } | j                                        o| j        j        d uS # t
          $ r Y dS w xY w)Nr   F)r?   rF  is_availableversionhipImportError)r?   s    r   	is_activezHIPDriver.is_active  sU    	LLL:**,,P%-2C42OP 	 	 	55	s   *- 
;;r   returnc                      t          |          S r  )r   )r   r   s     r   map_python_to_cpp_typez HIPDriver.map_python_to_cpp_type  s    }}r   c                     |                                  }| j                            |          }t          j        j        p|d         }|d         }t          d|                    d          d         |          S )NarchwarpSizerL  rA   r   )get_current_devicerC  r   r   runtimeoverride_archr   rK   )r   devicedevice_propertiesrS  r   s        r   get_current_targetzHIPDriver.get_current_target  si    ((** J<<VDD}*G.?.G%j1	

3 2I>>>r   c                 Z    dd l }|                    d|                                           S )Nr   rF  )r?   rX  rU  rG  s     r   get_active_torch_devicez!HIPDriver.get_active_torch_device  s+    ||FD$;$;$=$=>>>r   c                     ddl m} |S )Nr   )do_bench)triton.testingr^  )r   r^  s     r   get_benchmarkerzHIPDriver.get_benchmarker  s    ++++++r   c                 h    dd l }d}|                    t          |dz            |j        d          S )Nr   i      rF  )r   rX  )r?   emptyint)r   r?   
cache_sizes      r   get_empty_cache_for_benchmarkz'HIPDriver.get_empty_cache_for_benchmark  s8     '
{{3zQ//uy{PPPr   c                 .    |                                  d S r  )zero_)r   caches     r   clear_cachezHIPDriver.clear_cache  s    r   )r   r   r   r   rH  staticmethodrN  r   rQ  rZ  r\  r`  rf  rj  r   r   s   @r   rA  rA    s        ( ( ( ( (
     \     ? ? ?? ? ?
  Q Q Q      r   rA  )%	functoolsr!   r\   r   pathlibr   tritonr   triton.backends.compilerr   triton.backends.driverr   triton.runtimer   triton.runtime.buildr   triton.tools.tensor_descriptorr	   r;   rT   realpathrU   rS   r   r<   	lru_cachert   objectrv   r   r   r   r   r  r  r  rA  r   r   r   <module>rw     s,       				     				             . . . . . . , , , , , , & & & & & & 8 8 8 8 8 8 ; ; ; ; ; ;
'//"'**844
5
5Wi001- - -` \R \R \R~? ? ? ? ?v ? ? ?(
 
 
.       " M M M`  2q q q q q& q q q@. . . . .	 . . . . .r   