Add promoteTypes to ATen and torch._promote_types to python. (#5795)
This isn't hooked up to anything yet, but is necessary for both scalar binary ops in ATen and tensor constructor type inference in PyTorch.
gchanan committed Mar 15, 2018
1 parent 8277781 commit 6f5e869
Showing 2 changed files with 64 additions and 4 deletions.
45 changes: 41 additions & 4 deletions aten/src/ATen/ScalarType.h
@@ -11,12 +11,12 @@ namespace at {
#define AT_FORALL_SCALAR_TYPES(_) \
_(uint8_t,Byte,i) \
_(int8_t,Char,i) \
-_(double,Double,d) \
-_(float,Float,d) \
+_(int16_t,Short,i) \
_(int,Int,i) \
_(int64_t,Long,i) \
-_(int16_t,Short,i) \
-_(Half,Half,d)
+_(Half,Half,d) \
+_(float,Float,d) \
+_(double,Double,d)

enum class ScalarType {
#define DEFINE_ENUM(_1,n,_2) \
@@ -103,6 +103,43 @@ static inline bool isFloatingType(ScalarType t) {
t == ScalarType::Half);
}

static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
// This is generated according to NumPy's promote_types
#define u1 ScalarType::Byte
#define i1 ScalarType::Char
#define i2 ScalarType::Short
#define i4 ScalarType::Int
#define i8 ScalarType::Long
#define f2 ScalarType::Half
#define f4 ScalarType::Float
#define f8 ScalarType::Double
#define ud ScalarType::Undefined
static constexpr ScalarType _promoteTypesLookup
[static_cast<int>(ScalarType::NumOptions)]
[static_cast<int>(ScalarType::NumOptions)] = {
/* u1 i1 i2 i4 i8 f2 f4 f8, ud */
/* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, ud },
/* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, ud },
/* i2 */ { i2, i2, i2, i4, i8, f4, f4, f8, ud },
/* i4 */ { i4, i4, i4, i4, i8, f8, f8, f8, ud },
/* i8 */ { i8, i8, i8, i8, i8, f8, f8, f8, ud },
/* f2 */ { f2, f2, f4, f8, f8, f2, f4, f8, ud },
/* f4 */ { f4, f4, f4, f8, f8, f4, f4, f8, ud },
/* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, ud },
/* ud */ { ud, ud, ud, ud, ud, ud, ud, ud, ud },
};
#undef u1
#undef i1
#undef i2
#undef i4
#undef i8
#undef f2
#undef f4
#undef f8
#undef ud
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
}

struct Tensor;
typedef ArrayRef<int64_t> IntList;
typedef ArrayRef<Tensor> TensorList;
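For orientation, here is a minimal usage sketch of the new helper (not part of the commit). The header path and the standalone main are assumptions for illustration; the expected results are read directly off the lookup table above, which follows NumPy's promote_types, so Int combined with Float or Half promotes to Double.

// Usage sketch only: header path assumed to be <ATen/ScalarType.h>.
#include <ATen/ScalarType.h>
#include <cassert>

int main() {
  using at::ScalarType;
  // Expected values read off the promotion table above.
  assert(at::promoteTypes(ScalarType::Byte, ScalarType::Char) == ScalarType::Short);
  assert(at::promoteTypes(ScalarType::Int, ScalarType::Float) == ScalarType::Double);
  assert(at::promoteTypes(ScalarType::Half, ScalarType::Half) == ScalarType::Half);
  // Undefined never promotes to anything else.
  assert(at::promoteTypes(ScalarType::Long, ScalarType::Undefined) == ScalarType::Undefined);
  return 0;
}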
23 changes: 23 additions & 0 deletions tools/autograd/templates/python_torch_functions.cpp
@@ -82,6 +82,28 @@ static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg)
END_HANDLE_TH_ERRORS
}

static PyObject * THPVariable__promote_types(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
static PythonArgParser parser({
"_promote_types(Type type1, Type type2)",
});
ParsedArgs<2> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
const at::Type& t1 = r.type(0);
const at::Type& t2 = r.type(1);
if (t1.backend() != t2.backend()) {
at::runtime_error("_promote_types only supports types with the same backends. Got %s and %s.",
at::toString(t1.backend()), at::toString(t2.backend()));
}
ScalarType promoted = at::promoteTypes(t1.scalarType(), t2.scalarType());
return torch::autograd::utils::wrap(torch::getDtype(t1.backend(), promoted));
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}

static PyObject * THPVariable_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
@@ -100,6 +122,7 @@ static PyMethodDef torch_functions[] = {
{"hsmm", (PyCFunction)THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"saddmm", (PyCFunction)THPVariable_sspaddmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"spmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_promote_types", (PyCFunction)THPVariable__promote_types, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"tensor", (PyCFunction)THPVariable_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
${py_method_defs}
{NULL}
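To summarize the contract the new binding enforces before delegating to at::promoteTypes, here is a hedged sketch of the same logic as a plain ATen helper, without the Python argument parsing and wrapping. The function name and the use of std::runtime_error in place of at::runtime_error are illustrative assumptions, not code from this commit.

// Sketch only: mirrors the backend check and scalar-type promotion performed
// by THPVariable__promote_types above.
#include <ATen/ATen.h>
#include <stdexcept>

at::ScalarType promoteSameBackend(const at::Type& t1, const at::Type& t2) {
  // The binding refuses to promote across backends (e.g. CPU vs CUDA).
  if (t1.backend() != t2.backend()) {
    throw std::runtime_error("expected two types with the same backend");
  }
  // Scalar-type promotion follows the lookup table added to ScalarType.h.
  return at::promoteTypes(t1.scalarType(), t2.scalarType());
}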
