Browse Source

starpupy: define a new exception for starpupy module

HE Kun 4 years ago
parent
commit
281b2d20db
1 changed files with 34 additions and 8 deletions
  1. 34 8
      starpupy/src/starpu_task_wrapper.c

+ 34 - 8
starpupy/src/starpu_task_wrapper.c

@@ -29,6 +29,7 @@
 
 /*********************Functions passed in task_submit wrapper***********************/
 
+static PyObject *StarpupyError; /*starpupy error exception*/
 static PyObject *asyncio_module; /*python asyncio module*/
 
 static char* starpu_cloudpickle_dumps(PyObject *obj, PyObject **obj_bytes, Py_ssize_t *obj_data_size)
@@ -564,10 +565,15 @@ static PyObject* starpu_task_submit_wrapper(PyObject *self, PyObject *args)
 	task->callback_func=&cb_func;
 
 	/*call starpu_task_submit method*/
+	int ret;
 	Py_BEGIN_ALLOW_THREADS
-		int ret = starpu_task_submit(task);
-		assert(ret==0);
+	ret = starpu_task_submit(task);
 	Py_END_ALLOW_THREADS
+	if (ret!=0)
+	{
+		PyErr_Format(StarpupyError, "Unexpected value %d returned for starpu_task_submit", ret);
+		return NULL;
+	}
 
 	if (strcmp(tp_perf, "PyCapsule")==0)
 	{
@@ -715,15 +721,35 @@ static struct PyModuleDef starpupymodule =
 PyMODINIT_FUNC
 PyInit_starpupy(void)
 {
+	PyObject *m;
+
+	/*module import initialization*/
+	m = PyModule_Create(&starpupymodule);
+	if (m == NULL)
+		return NULL;
+
+	StarpupyError = PyErr_NewException("StarPUPy.Error", NULL, NULL);
+	Py_XINCREF(StarpupyError);
+    if (PyModule_AddObject(m, "error", StarpupyError) < 0) {
+        Py_XDECREF(StarpupyError);
+        Py_CLEAR(StarpupyError);
+        Py_DECREF(m);
+        return NULL;
+    }
+
 #if PY_MAJOR_VERSION < 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 9)
 	PyEval_InitThreads();
 #endif
-	//PyThreadState* st = PyEval_SaveThread();
-	Py_BEGIN_ALLOW_THREADS
 	/*starpu initialization*/
-	int ret = starpu_init(NULL);
-	assert(ret==0);
+	int ret;
+	Py_BEGIN_ALLOW_THREADS
+	ret = starpu_init(NULL);
 	Py_END_ALLOW_THREADS
+	if (ret!=0)
+	{
+		PyErr_Format(StarpupyError, "Unexpected value %d returned for starpu_init", ret);
+		return NULL;
+	}
 
 	/*python asysncio import*/
 	asyncio_module = PyImport_ImportModule("asyncio");
@@ -732,7 +758,7 @@ PyInit_starpupy(void)
 	/*numpy import array*/
 	import_array();
 #endif
-	/*module import initialization*/
-	return PyModule_Create(&starpupymodule);
+	
+	return m;
 }
 /***********************************************************************************/