10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
16template<
typename CustomUnaryFunc,
typename XprType>
17struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
19 typedef typename XprType::Scalar Scalar;
20 typedef typename XprType::StorageKind StorageKind;
21 typedef typename XprType::Index
Index;
22 typedef typename XprType::Nested Nested;
23 typedef typename remove_reference<Nested>::type _Nested;
24 static const int NumDimensions = traits<XprType>::NumDimensions;
25 static const int Layout = traits<XprType>::Layout;
26 typedef typename traits<XprType>::PointerType PointerType;
29template<
typename CustomUnaryFunc,
typename XprType>
30struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense>
32 typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>EIGEN_DEVICE_REF type;
35template<
typename CustomUnaryFunc,
typename XprType>
36struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
38 typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
// Expression node pairing a user-supplied unary functor (CustomUnaryFunc) with
// its input expression (XprType). The node only stores the two; the
// TensorEvaluator for this op (later in this file) calls func() and
// expression() to compute the result.
// NOTE(review): the embedded numbering skips source lines 50, 52, 60-61, 63-64
// and 67-68, so access specifiers and possibly other declarations are not
// visible in this listing — verify against the real header.
48template <
typename CustomUnaryFunc,
typename XprType>
49class TensorCustomUnaryOp :
public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors> {
51 typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
// Coefficients are reported with the input expression's CoeffReturnType.
53 typedef typename XprType::CoeffReturnType CoeffReturnType;
54 typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
55 typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
56 typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;
// Stores the input (held as XprType::Nested) and a copy of the functor.
58 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(
const XprType& expr,
const CustomUnaryFunc& func)
59 : m_expr(expr), m_func(func) {}
// Returns the functor supplied at construction.
62 const CustomUnaryFunc& func()
const {
return m_func; }
// Returns the wrapped input expression.
65 const typename internal::remove_all<typename XprType::Nested>::type&
66 expression()
const {
return m_expr; }
// Input expression, held via the input's nesting policy.
69 typename XprType::Nested m_expr;
// Functor, copied by value at construction.
70 const CustomUnaryFunc m_func;
// Evaluator for TensorCustomUnaryOp on a given Device. As visible in this file,
// evalTo() runs the user functor into a full output buffer, and coeff()/packet()
// then read from that buffer (m_result).
// NOTE(review): this listing drops source lines here. The numbering jumps
// 75 -> 79 (the `struct TensorEvaluator<...>` line and the ArgType typedef are
// not visible), 80 -> 82 (the Dimensions typedef used below is not visible),
// and 88 -> 92, 92 -> 94, 95 -> 101 (the enum opener and some enum members are
// not visible). Verify against the real header.
75template<
typename CustomUnaryFunc,
typename XprType,
typename Device>
79 typedef typename internal::traits<ArgType>::Index
Index;
80 static const int NumDims = internal::traits<ArgType>::NumDimensions;
// Scalar comes from the op (ArgType); CoeffReturnType from the input XprType.
82 typedef typename internal::remove_const<typename ArgType::Scalar>::type
Scalar;
83 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type
CoeffReturnType;
84 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
85 static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
86 typedef typename Eigen::internal::traits<XprType>::PointerType TensorPointerType;
// Device-specific pointer type used for the result buffer.
87 typedef StorageMemory<CoeffReturnType, Device> Storage;
88 typedef typename Storage::Type EvaluatorPointerType;
// Packet access is advertised whenever the device packet holds more than one value.
92 PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
94 PreferBlockAccess =
false,
95 Layout = TensorEvaluator<XprType, Device>::Layout,
// Block-based evaluation is not supported for custom ops.
101 typedef internal::TensorBlockNotImplemented TensorBlock;
104 EIGEN_STRONG_INLINE TensorEvaluator(
const ArgType& op,
const Device& device)
105 : m_op(op), m_device(device), m_result(NULL)
107 m_dimensions = op.func().dimensions(op.expression());
110 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
return m_dimensions; }
// Visible path: allocates a temporary of dimensions().TotalSize() elements via
// m_device.allocate_temp and stores it in m_result.
// NOTE(review): source lines 113-116 and 119-123 are missing from this listing,
// so the handling of a caller-provided `data` buffer and the return statements
// are not visible here.
// NOTE(review): the temporary is typed and read as CoeffReturnType, but sized
// with sizeof(Scalar); the binary evaluator sizes its buffer with
// sizeof(CoeffReturnType). Confirm the two types always have equal size, or
// switch to sizeof(CoeffReturnType).
112 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
117 m_result =
static_cast<EvaluatorPointerType
>(m_device.get( (CoeffReturnType*)
118 m_device.allocate_temp(dimensions().TotalSize() *
sizeof(Scalar))));
// Releases the temporary result buffer via m_device.deallocate_temp.
// NOTE(review): source lines 125 and 127-130 are missing from this listing. As
// shown, deallocation is unconditional even though m_result starts as NULL (see
// the constructor); the binary evaluator's cleanup guards with
// m_result != NULL. Verify whether the real header has an equivalent guard on
// the dropped line.
124 EIGEN_STRONG_INLINE
void cleanup() {
126 m_device.deallocate_temp(m_result);
131 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index)
const {
132 return m_result[index];
135 template<
int LoadMode>
136 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index)
const {
137 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
140 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(
bool vectorized)
const {
142 return TensorOpCost(
sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
145 EIGEN_DEVICE_FUNC EvaluatorPointerType data()
const {
return m_result; }
// Hook that receives a SYCL command-group handler (cl::sycl::handler).
// NOTE(review): the body (source lines 150-153) and any surrounding
// preprocessor guard are not visible in this listing; presumably it binds the
// result buffer to `cgh` — confirm against the real header.
149 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void bind(cl::sycl::handler &cgh)
const {
155 void evalTo(EvaluatorPointerType data) {
156 TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result(m_device.get(data), m_dimensions);
157 m_op.func().eval(m_op.expression(), result, m_device);
// Output shape, computed by the functor in the constructor.
160 Dimensions m_dimensions;
// NOTE(review): source line 161 is missing; the constructor initializes an
// m_op member whose declaration is not visible in this listing.
// Device used for temporary allocation, deallocation and functor evaluation.
162 const Device EIGEN_DEVICE_REF m_device;
// Result buffer: NULL after construction, assigned when the temporary is allocated.
163 EvaluatorPointerType m_result;
176template<
typename CustomBinaryFunc,
typename LhsXprType,
typename RhsXprType>
177struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
179 typedef typename internal::promote_storage_type<
typename LhsXprType::Scalar,
180 typename RhsXprType::Scalar>::ret Scalar;
181 typedef typename internal::promote_storage_type<
typename LhsXprType::CoeffReturnType,
182 typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
183 typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
184 typename traits<RhsXprType>::StorageKind>::ret StorageKind;
185 typedef typename promote_index_type<typename traits<LhsXprType>::Index,
186 typename traits<RhsXprType>::Index>::type
Index;
187 typedef typename LhsXprType::Nested LhsNested;
188 typedef typename RhsXprType::Nested RhsNested;
189 typedef typename remove_reference<LhsNested>::type _LhsNested;
190 typedef typename remove_reference<RhsNested>::type _RhsNested;
191 static const int NumDimensions = traits<LhsXprType>::NumDimensions;
192 static const int Layout = traits<LhsXprType>::Layout;
193 typedef typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
194 typename traits<LhsXprType>::PointerType,
typename traits<RhsXprType>::PointerType>::type PointerType;
197template<
typename CustomBinaryFunc,
typename LhsXprType,
typename RhsXprType>
198struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense>
200 typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
203template<
typename CustomBinaryFunc,
typename LhsXprType,
typename RhsXprType>
204struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
206 typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type;
// Expression node pairing a user-supplied binary functor (CustomBinaryFunc)
// with its two input expressions. The node only stores them; the
// TensorEvaluator for this op (later in this file) calls func(),
// lhsExpression() and rhsExpression() to compute the result.
// NOTE(review): the embedded numbering skips source lines 215-216, 218, 223,
// 225, 227-228, 230-231, 234-235 and 238-239, so the class-body opening brace,
// access specifiers and possibly other declarations are not visible in this
// listing — verify against the real header.
213template<
typename CustomBinaryFunc,
typename LhsXprType,
typename RhsXprType>
214class TensorCustomBinaryOp :
public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors>
217 typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
// Unlike the unary op, CoeffReturnType here is the promoted type from traits<>.
219 typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
220 typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
221 typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
222 typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;
// Stores both inputs (held via their Nested types) and a copy of the functor.
224 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(
const LhsXprType& lhs,
const RhsXprType& rhs,
const CustomBinaryFunc& func)
226 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}
// Returns the functor supplied at construction.
229 const CustomBinaryFunc& func()
const {
return m_func; }
// Returns the left-hand input expression.
232 const typename internal::remove_all<typename LhsXprType::Nested>::type&
233 lhsExpression()
const {
return m_lhs_xpr; }
// Returns the right-hand input expression.
236 const typename internal::remove_all<typename RhsXprType::Nested>::type&
237 rhsExpression()
const {
return m_rhs_xpr; }
// Inputs, held via each operand's nesting policy.
240 typename LhsXprType::Nested m_lhs_xpr;
241 typename RhsXprType::Nested m_rhs_xpr;
// Functor, copied by value at construction.
242 const CustomBinaryFunc m_func;
// Evaluator for TensorCustomBinaryOp on a given Device. As visible in this
// file, evalTo() runs the user functor into a full output buffer, and
// coeff()/packet() then read from that buffer (m_result).
// NOTE(review): this listing drops source lines here. The numbering jumps
// 247 -> 251 (the `struct TensorEvaluator<...>` line and the XprType typedef
// are not visible), 252 -> 254 (the Dimensions typedef used below is not
// visible), and 261 -> 265, 265 -> 267, 268 -> 274 (the enum opener and some
// enum members are not visible). Verify against the real header.
247template<
typename CustomBinaryFunc,
typename LhsXprType,
typename RhsXprType,
typename Device>
251 typedef typename internal::traits<XprType>::Index
Index;
252 static const int NumDims = internal::traits<XprType>::NumDimensions;
254 typedef typename XprType::Scalar
Scalar;
255 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type
CoeffReturnType;
256 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
257 static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
259 typedef typename Eigen::internal::traits<XprType>::PointerType TensorPointerType;
// Device-specific pointer type used for the result buffer.
260 typedef StorageMemory<CoeffReturnType, Device> Storage;
261 typedef typename Storage::Type EvaluatorPointerType;
// Packet access is advertised whenever the device packet holds more than one value.
265 PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
267 PreferBlockAccess =
false,
// Layout follows the left-hand operand.
268 Layout = TensorEvaluator<LhsXprType, Device>::Layout,
// Block-based evaluation is not supported for custom ops.
274 typedef internal::TensorBlockNotImplemented TensorBlock;
277 EIGEN_STRONG_INLINE TensorEvaluator(
const XprType& op,
const Device& device)
278 : m_op(op), m_device(device), m_result(NULL)
280 m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
283 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
return m_dimensions; }
// Visible path: allocates a temporary of dimensions().TotalSize() elements of
// CoeffReturnType via m_device.allocate_temp and stores it in m_result.
// NOTE(review): source lines 286-289 and 292-296 are missing from this listing,
// so the handling of a caller-provided `data` buffer and the return statements
// are not visible here.
285 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
290 m_result =
static_cast<EvaluatorPointerType
>(m_device.get( (CoeffReturnType*)
291 m_device.allocate_temp(dimensions().TotalSize() *
sizeof(CoeffReturnType))));
// Releases the temporary result buffer via m_device.deallocate_temp, but only
// when m_result is non-NULL (i.e. a temporary was allocated).
// NOTE(review): source lines 300-303 are missing from this listing (closing
// braces and possibly a reset of m_result); verify against the real header.
297 EIGEN_STRONG_INLINE
void cleanup() {
298 if (m_result != NULL) {
299 m_device.deallocate_temp(m_result);
304 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index)
const {
305 return m_result[index];
308 template<
int LoadMode>
309 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index)
const {
310 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
313 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(
bool vectorized)
const {
315 return TensorOpCost(
sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
318 EIGEN_DEVICE_FUNC EvaluatorPointerType data()
const {
return m_result; }
// Hook that receives a SYCL command-group handler (cl::sycl::handler).
// NOTE(review): the body and any surrounding preprocessor guard are not
// visible in this listing (numbering jumps 322 -> 328); presumably it binds the
// result buffer to `cgh` — confirm against the real header.
322 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void bind(cl::sycl::handler &cgh)
const {
328 void evalTo(EvaluatorPointerType data) {
329 TensorMap<Tensor<CoeffReturnType, NumDims, Layout> > result(m_device.get(data), m_dimensions);
330 m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
// Output shape, computed by the functor in the constructor.
333 Dimensions m_dimensions;
// NOTE(review): source line 334 is missing; the constructor initializes an
// m_op member whose declaration is not visible in this listing.
// Device used for temporary allocation, deallocation and functor evaluation.
335 const Device EIGEN_DEVICE_REF m_device;
// Result buffer: NULL after construction, assigned when the temporary is allocated.
336 EvaluatorPointerType m_result;
The tensor base class.
Definition TensorForwardDeclarations.h:56
Tensor custom binary operation class.
Definition TensorCustomOp.h:215
Tensor custom unary operation class.
Definition TensorCustomOp.h:49
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor evaluator class.
Definition TensorEvaluator.h:27