substitute_indexing.jl 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  """
      substitute_indexing(expr :: StarpuExpr)

  Return the result of `apply`-ing an index-rewriting function over `expr`.

  Every `StarpuExprRef` node handed to the rewriting function is transformed
  as follows (all other nodes are returned unchanged):

  - no index: the reference is returned unchanged;
  - one index `i`: the reference is rebuilt with the single index `i - 1`;
  - two indexes `(i, j)` on a variable named `ptr_<id>`: the reference is
    rebuilt with the single flattened index `(i + (j - 1) * ld_<id>) - 1`,
    where `ld_<id>` is a variable whose name is derived from the indexed one
    (presumably the leading dimension of the matrix -- TODO confirm).

  # Throws
  - `ErrorException` if the indexed expression is not a `StarpuExprVar`.
  - `ErrorException` if three or more indexes are used.
  - `ErrorException` if two indexes are used on a variable whose name does not
    start with `ptr_` followed by at least one character.
  """
  function substitute_indexing(expr :: StarpuExpr)

      function func_to_run(x :: StarpuExpr)

          # Only indexed references are rewritten.
          if !isa(x, StarpuExprRef)
              return x
          end

          if !isa(x.ref, StarpuExprVar)
              error("Only variable indexing is allowed") #TODO allow more ?
          end

          nb_indexes = length(x.indexes)

          if nb_indexes >= 3
              error("Indexing with more than 2 indexes is not allowed") # TODO : blocks
          end

          if nb_indexes == 0
              return x
          end

          if nb_indexes == 1
              #TODO : add field "offset" from STARPU_VECTOR_GET interface
              #TODO : detect when it is a matrix used with one index only
              new_index = StarpuExprCall(:-, [x.indexes[1], StarpuExprValue(1)])
              return StarpuExprRef(x.ref, [new_index])
          end

          # From here on nb_indexes == 2.
          var_name = String(x.ref.name)

          # The prefix must be anchored at the start of the name: the suffix is
          # extracted below at a fixed offset, so accepting "ptr_" in the middle
          # of the name would build a bogus `ld_` variable (or slice the string
          # at an invalid index).
          if !startswith(var_name, "ptr_") || isempty(var_name[5:end])
              error("Invalid variable ($var_name) for multiple index dereferencing")
          end

          var_id = var_name[5:end]
          # TODO : check if this variable is legit (var_name must refer to a matrix)
          ld_name = Symbol("ld_", var_id)

          # Build (i + (j - 1) * ld) - 1 step by step.
          col_offset = StarpuExprCall(:(-), [x.indexes[2], StarpuExprValue(1)])
          col_offset = StarpuExprCall(:(*), [col_offset, StarpuExprVar(ld_name)])
          new_index = StarpuExprCall(:(+), [x.indexes[1], col_offset])
          new_index = StarpuExprCall(:(-), [new_index, StarpuExprValue(1)])

          return StarpuExprRef(x.ref, [new_index])
      end

      return apply(func_to_run, expr)
  end