Skip to content

Commit

Permalink
Added torch.isTypeOf after discussion with Koray
Browse files Browse the repository at this point in the history
  • Loading branch information
malcolmreynolds committed Oct 14, 2014
1 parent 8e21cd2 commit cf1d17c
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 0 deletions.
13 changes: 13 additions & 0 deletions doc/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,19 @@ Convenience method for the [type](#torch.Tensor.type) method. Equivalent to
type(tensor:type())
```

<a name="torch.Tensor.isTensor"?>
### [boolean] isTensor(object) ###

Returns `true` iff the provided `object` is one of the `torch.*Tensor` types.

```lua
> =torch.isTensor(torch.randn(3,4))
true
> =torch.isTensor(torch.randn(3,4)[1])
true
> =torch.isTensor(torch.randn(3,4)[1][2])
false
```

<a name="torch.Tensor.byte"/>
### [Tensor] byte(), char(), short(), int(), long(), float(), double() ###
Expand Down
10 changes: 10 additions & 0 deletions doc/utility.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@ Returns `nil` if `object` is not a Torch object.

This is different from the _object_ id returned by [torch.pointer()](#torch.pointer).

<a name="torch.isTypeOf"/>
### [boolean] isTypeOf(object, typeSpec) ###

Checks if a given object is an instance of the type specified by typeSpec.
Typespec can be a string (including a string.find pattern) or the constructor
object for a Torch class. This function traverses up the class hierarchy,
so if b is an instance of B which is a subclass of A, then
`torch.isTypeOf(b, B)` and `torch.isTypeOf(b, A)` will both return true.


<a name="torch.newmetatable"/>
### [table] torch.newmetatable(name, parentName, constructor) ###

Expand Down
27 changes: 27 additions & 0 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,27 @@ function torch.type(obj)
return class
end

--[[ See if a given object is an instance of the provided torch class. ]]
function torch.isTypeOf(obj, typeSpec)
-- typeSpec can be provided as either a string, regexp, or the constructor. If
-- the constructor is used, we look in the __typename field of the
-- metatable to find a string to compare to.
if type(typeSpec) ~= 'string' then
typeSpec = getmetatable(typeSpec).__typename
assert(type(typeSpec) == 'string',
"type must be provided as [regexp] string, or factory")
end

local mt = getmetatable(obj)
while mt do
if mt.__typename and mt.__typename:find(typeSpec) then
return true
end
mt = getmetatable(mt)
end
return false
end

torch.setdefaulttensortype('torch.DoubleTensor')

include('Tensor.lua')
Expand All @@ -93,4 +114,10 @@ include('FFI.lua')
include('Tester.lua')
include('test.lua')

function torch.isTensor(obj)
return torch.isTypeOf(obj, 'torch.*Tensor')
end
-- alias for convenience
torch.Tensor.isTensor = torch.isTensor

return torch
27 changes: 27 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1403,6 +1403,33 @@ function torchtest.type()
end
end

function torchtest.isTypeOfInheritance()
do
local A = torch.class('A')
local B, parB = torch.class('B', 'A')
local C, parC = torch.class('C', 'A')
end
local a, b, c = A(), B(), C()

mytester:assert(torch.isTypeOf(a, 'A'), 'isTypeOf error, string spec')
mytester:assert(torch.isTypeOf(a, A), 'isTypeOf error, constructor')
mytester:assert(torch.isTypeOf(b, 'B'), 'isTypeOf error child class')
mytester:assert(torch.isTypeOf(b, B), 'isTypeOf error child class ctor')
mytester:assert(torch.isTypeOf(b, 'A'), 'isTypeOf error: inheritance')
mytester:assert(torch.isTypeOf(b, A), 'isTypeOf error: inheritance')
mytester:assert(not torch.isTypeOf(c, 'B'), 'isTypeOf error: common parent')
mytester:assert(not torch.isTypeOf(c, B), 'isTypeOf error: common parent')
end


function torchtest.isTensor()
local t = torch.randn(3,4)
mytester:assert(torch.isTensor(t), 'error in isTensor')
mytester:assert(torch.isTensor(t[1]), 'error in isTensor for subTensor')
mytester:assert(not torch.isTensor(t[1][2]), 'false positive in isTensor')
mytester:assert(torch.Tensor.isTensor(t), 'alias not working')
end

function torchtest.view()
local tensor = torch.rand(15)
local template = torch.rand(3,5)
Expand Down

0 comments on commit cf1d17c

Please sign in to comment.