@@ -386,7 +386,7 @@ def place(arr, mask, vals):
386386 return call_origin (numpy .place , arr , mask , vals )
387387
388388
389- def put (input , ind , v , mode = 'raise' ):
389+ def put (x1 , ind , v , mode = 'raise' ):
390390 """
391391 Replaces specified elements of an array with given values.
392392 For full documentation refer to :obj:`numpy.put`.
@@ -397,22 +397,21 @@ def put(input, ind, v, mode='raise'):
397397 Not supported parameter mode.
398398 """
399399
400- if not use_origin_backend (input ):
401- if not isinstance (input , dparray ):
402- pass
403- elif mode != 'raise' :
400+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
401+ if x1_desc :
402+ if mode != 'raise' :
404403 pass
405404 elif type (ind ) != type (v ):
406405 pass
407- elif numpy .max (ind ) >= input .size or numpy .min (ind ) + input .size < 0 :
406+ elif numpy .max (ind ) >= x1_desc .size or numpy .min (ind ) + x1_desc .size < 0 :
408407 pass
409408 else :
410- return dpnp_put (input , ind , v )
409+ return dpnp_put (x1_desc , ind , v )
411410
412- return call_origin (numpy .put , input , ind , v , mode )
411+ return call_origin (numpy .put , x1 , ind , v , mode )
413412
414413
415- def put_along_axis (arr , indices , values , axis ):
414+ def put_along_axis (x1 , indices , values , axis ):
416415 """
417416 Put values into the destination array by matching 1d index and data slices.
418417 For full documentation refer to :obj:`numpy.put_along_axis`.
@@ -422,62 +421,25 @@ def put_along_axis(arr, indices, values, axis):
422421 :obj:`take_along_axis` : Take values from the input array by matching 1d index and data slices.
423422 """
424423
425- if not use_origin_backend (arr ):
426- if not isinstance (arr , dparray ):
427- pass
428- elif not isinstance (indices , dparray ):
429- pass
430- elif arr .ndim != indices .ndim :
424+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
425+ indices_desc = dpnp .get_dpnp_descriptor (indices )
426+ values_desc = dpnp .get_dpnp_descriptor (values )
427+ if x1_desc and indices_desc and values_desc :
428+ if x1_desc .ndim != indices_desc .ndim :
431429 pass
432430 elif not isinstance (axis , int ):
433431 pass
434- elif axis >= arr .ndim :
432+ elif axis >= x1_desc .ndim :
435433 pass
436- elif not isinstance ( values , ( dparray , tuple , list )) and not dpnp . isscalar ( values ) :
434+ elif indices_desc . size != values_desc . size :
437435 pass
438- elif not dpnp .isscalar (values ) and ((isinstance (values , dparray ) and indices .size != values .size ) or
439- ((isinstance (values , (tuple , list )) and indices .size != len (values )))):
440- pass
441- elif arr .ndim == indices .ndim :
442- val_list = []
443- for i in list (indices .shape )[:- 1 ]:
444- if i == 1 :
445- val_list .append (True )
446- else :
447- val_list .append (False )
448- if not all (val_list ):
449- pass
450- else :
451- if dpnp .isscalar (values ):
452- values_size = 1
453- values_ = dparray (values_size , dtype = arr .dtype )
454- values_ [0 ] = values
455- elif isinstance (values , dparray ):
456- values_ = values
457- else :
458- values_size = len (values )
459- values_ = dparray (values_size , dtype = arr .dtype )
460- for i in range (values_size ):
461- values_ [i ] = values [i ]
462- return dpnp_put_along_axis (arr , indices , values_ , axis )
463436 else :
464- if dpnp .isscalar (values ):
465- values_size = 1
466- values_ = dparray (values_size , dtype = arr .dtype )
467- values_ [0 ] = values
468- elif isinstance (values , dparray ):
469- values_ = values
470- else :
471- values_size = len (values )
472- values_ = dparray (values_size , dtype = arr .dtype )
473- for i in range (values_size ):
474- values_ [i ] = values [i ]
475- return dpnp_put_along_axis (arr , indices , values_ , axis )
437+ return dpnp_put_along_axis (x1_desc , indices_desc , values_desc , axis )
476438
477- return call_origin (numpy .put_along_axis , arr , indices , values , axis )
439+ return call_origin (numpy .put_along_axis , x1 , indices , values , axis )
478440
479441
480- def putmask (arr , mask , values ):
442+ def putmask (x1 , mask , values ):
481443 """
482444 Changes elements of an array based on conditional and input values.
483445 For full documentation refer to :obj:`numpy.putmask`.
@@ -487,17 +449,13 @@ def putmask(arr, mask, values):
487449 Input arrays ``arr``, ``mask`` and ``values`` are supported as :obj:`dpnp.ndarray`.
488450 """
489451
490- if not use_origin_backend (arr ):
491- if not isinstance (arr , dparray ):
492- pass
493- elif not isinstance (mask , dparray ):
494- pass
495- elif not isinstance (values , dparray ):
496- pass
497- else :
498- return dpnp_putmask (arr , mask , values )
452+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
453+ mask_desc = dpnp .get_dpnp_descriptor (mask )
454+ values_desc = dpnp .get_dpnp_descriptor (values )
455+ if x1_desc and mask_desc and values_desc :
456+ return dpnp_putmask (x1 , mask , values )
499457
500- return call_origin (numpy .putmask , arr , mask , values )
458+ return call_origin (numpy .putmask , x1 , mask , values )
501459
502460
503461def select (condlist , choicelist , default = 0 ):
@@ -510,6 +468,7 @@ def select(condlist, choicelist, default=0):
510468 Arrays of input lists are supported as :obj:`dpnp.ndarray`.
511469 Parameter ``default`` are supported only with default values.
512470 """
471+
513472 if not use_origin_backend ():
514473 if not isinstance (condlist , list ):
515474 pass
@@ -537,7 +496,7 @@ def select(condlist, choicelist, default=0):
537496 return call_origin (numpy .select , condlist , choicelist , default )
538497
539498
540- def take (input , indices , axis = None , out = None , mode = 'raise' ):
499+ def take (x1 , indices , axis = None , out = None , mode = 'raise' ):
541500 """
542501 Take elements from an array.
543502 For full documentation refer to :obj:`numpy.take`.
@@ -554,24 +513,22 @@ def take(input, indices, axis=None, out=None, mode='raise'):
554513 :obj:`take_along_axis` : Take elements by matching the array and the index arrays.
555514 """
556515
557- if not use_origin_backend (input ):
558- if not isinstance (input , dparray ):
559- pass
560- elif not isinstance (indices , dparray ):
561- pass
562- elif axis is not None :
516+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
517+ indices_desc = dpnp .get_dpnp_descriptor (indices )
518+ if x1_desc and indices_desc :
519+ if axis is not None :
563520 pass
564521 elif out is not None :
565522 pass
566523 elif mode != 'raise' :
567524 pass
568525 else :
569- return dpnp_take (input , indices )
526+ return dpnp_take (x1_desc , indices_desc )
570527
571- return call_origin (numpy .take , input , indices , axis , out , mode )
528+ return call_origin (numpy .take , x1 , indices , axis , out , mode )
572529
573530
574- def take_along_axis (arr , indices , axis ):
531+ def take_along_axis (x1 , indices , axis ):
575532 """
576533 Take values from the input array by matching 1d index and data slices.
577534 For full documentation refer to :obj:`numpy.take_along_axis`.
@@ -582,32 +539,30 @@ def take_along_axis(arr, indices, axis):
582539 :obj:`put_along_axis` : Put values into the destination array by matching 1d index and data slices.
583540 """
584541
585- if not use_origin_backend (arr ):
586- if not isinstance (arr , dparray ):
587- pass
588- elif not isinstance (indices , dparray ):
589- pass
590- elif arr .ndim != indices .ndim :
542+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
543+ indices_desc = dpnp .get_dpnp_descriptor (indices )
544+ if x1_desc and indices_desc :
545+ if x1_desc .ndim != indices_desc .ndim :
591546 pass
592547 elif not isinstance (axis , int ):
593548 pass
594- elif axis >= arr .ndim :
549+ elif axis >= x1_desc .ndim :
595550 pass
596- elif arr .ndim == indices .ndim :
551+ elif x1_desc .ndim == indices_desc .ndim :
597552 val_list = []
598- for i in list (indices .shape )[:- 1 ]:
553+ for i in list (indices_desc .shape )[:- 1 ]:
599554 if i == 1 :
600555 val_list .append (True )
601556 else :
602557 val_list .append (False )
603558 if not all (val_list ):
604559 pass
605560 else :
606- return dpnp_take_along_axis (arr , indices , axis )
561+ return dpnp_take_along_axis (x1 , indices , axis )
607562 else :
608- return dpnp_take_along_axis (arr , indices , axis )
563+ return dpnp_take_along_axis (x1 , indices , axis )
609564
610- return call_origin (numpy .take_along_axis , arr , indices , axis )
565+ return call_origin (numpy .take_along_axis , x1 , indices , axis )
611566
612567
613568def tril_indices (n , k = 0 , m = None ):
@@ -644,7 +599,7 @@ def tril_indices(n, k=0, m=None):
644599 return call_origin (numpy .tril_indices , n , k , m )
645600
646601
647- def tril_indices_from (arr , k = 0 ):
602+ def tril_indices_from (x1 , k = 0 ):
648603 """
649604 Return the indices for the lower-triangle of arr.
650605 See `tril_indices` for full details.
@@ -659,13 +614,12 @@ def tril_indices_from(arr, k=0):
659614 Diagonal offset (see `tril` for details).
660615 """
661616
662- is_arr_dparray = isinstance (arr , dparray )
663-
664- if (not use_origin_backend (arr ) and is_arr_dparray ):
617+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
618+ if x1_desc :
665619 if isinstance (k , int ):
666- return dpnp_tril_indices_from (arr , k )
620+ return dpnp_tril_indices_from (x1_desc , k )
667621
668- return call_origin (numpy .tril_indices_from , arr , k )
622+ return call_origin (numpy .tril_indices_from , x1 , k )
669623
670624
671625def triu_indices (n , k = 0 , m = None ):
@@ -702,7 +656,7 @@ def triu_indices(n, k=0, m=None):
702656 return call_origin (numpy .triu_indices , n , k , m )
703657
704658
705- def triu_indices_from (arr , k = 0 ):
659+ def triu_indices_from (x1 , k = 0 ):
706660 """
707661 Return the indices for the lower-triangle of arr.
708662 See `tril_indices` for full details.
@@ -717,10 +671,9 @@ def triu_indices_from(arr, k=0):
717671 Diagonal offset (see `tril` for details).
718672 """
719673
720- is_arr_dparray = isinstance (arr , dparray )
721-
722- if (not use_origin_backend (arr ) and is_arr_dparray ):
674+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
675+ if x1_desc :
723676 if isinstance (k , int ):
724- return dpnp_triu_indices_from (arr , k )
677+ return dpnp_triu_indices_from (x1_desc , k )
725678
726- return call_origin (numpy .triu_indices_from , arr , k )
679+ return call_origin (numpy .triu_indices_from , x1 , k )
0 commit comments