#ifndef EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
#define EIGEN_CXX11_TENSOR_TENSOR_PATCH_H

namespace Eigen {
namespace internal {
template<typename PatchDim, typename XprType>
struct traits<TensorPatchOp<PatchDim, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions + 1;
  static const int Layout = XprTraits::Layout;
  typedef typename XprTraits::PointerType PointerType;
};
template<typename PatchDim, typename XprType>
struct eval<TensorPatchOp<PatchDim, XprType>, Eigen::Dense>
{
  typedef const TensorPatchOp<PatchDim, XprType>& type;
};
template<typename PatchDim, typename XprType>
struct nested<TensorPatchOp<PatchDim, XprType>, 1, typename eval<TensorPatchOp<PatchDim, XprType> >::type>
{
  typedef TensorPatchOp<PatchDim, XprType> type;
};

}  // end namespace internal
/** \class TensorPatch
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor patch class.
  */
template<typename PatchDim, typename XprType>
class TensorPatchOp : public TensorBase<TensorPatchOp<PatchDim, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorPatchOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename XprType::CoeffReturnType CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorPatchOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorPatchOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorPatchOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPatchOp(const XprType& expr, const PatchDim& patch_dims)
        : m_xpr(expr), m_patch_dims(patch_dims) {}

    EIGEN_DEVICE_FUNC
    const PatchDim& patch_dims() const { return m_patch_dims; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
    const PatchDim m_patch_dims;
};
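// Usage sketch (an illustrative addition, not part of the original source):
// the op is normally constructed via TensorBase::extract_patches(). Assuming
// a ColMajor 3x4 float input and 2x2 patches, the result is a 2x2x6 tensor
// holding the (3-2+1) * (4-2+1) == 6 overlapping patches:
//
//   Eigen::Tensor<float, 2> input(3, 4);
//   input.setRandom();
//   Eigen::array<Eigen::Index, 2> patch_dims{{2, 2}};
//   Eigen::Tensor<float, 3> patches = input.extract_patches(patch_dims);
//   // patches.dimension(0) == 2, patches.dimension(1) == 2,
//   // patches.dimension(2) == 6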
// Eval as rvalue
template<typename PatchDim, typename ArgType, typename Device>
struct TensorEvaluator<const TensorPatchOp<PatchDim, ArgType>, Device>
{
  typedef TensorPatchOp<PatchDim, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value + 1;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;
  enum {
    IsAligned         = false,
    PacketAccess      = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess       = false,
    PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
    Layout            = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess       = false,
    RawAccess         = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//
  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device)
  {
    Index num_patches = 1;
    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
    const PatchDim& patch_dims = op.patch_dims();
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = 0; i < NumDims-1; ++i) {
        m_dimensions[i] = patch_dims[i];
        num_patches *= (input_dims[i] - patch_dims[i] + 1);
      }
      m_dimensions[NumDims-1] = num_patches;

      m_inputStrides[0] = 1;
      m_patchStrides[0] = 1;
      for (int i = 1; i < NumDims-1; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_patchStrides[i] = m_patchStrides[i-1] * (input_dims[i-1] - patch_dims[i-1] + 1);
      }
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      for (int i = 0; i < NumDims-1; ++i) {
        m_dimensions[i+1] = patch_dims[i];
        num_patches *= (input_dims[i] - patch_dims[i] + 1);
      }
      m_dimensions[0] = num_patches;

      m_inputStrides[NumDims-2] = 1;
      m_patchStrides[NumDims-2] = 1;
      for (int i = NumDims-3; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_patchStrides[i] = m_patchStrides[i+1] * (input_dims[i+1] - patch_dims[i+1] + 1);
      }
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }
  }
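  // Worked example (an illustrative addition, not part of the original
  // source): for a ColMajor 3x4 input with 2x2 patches, NumDims == 3 and the
  // constructor above computes
  //   m_dimensions    = {2, 2, 6}   (6 == (3-2+1) * (4-2+1) patches)
  //   m_inputStrides  = {1, 3}      (strides of the 3x4 input)
  //   m_patchStrides  = {1, 2}      (strides over the 2x3 grid of patch origins)
  //   m_outputStrides = {1, 2, 4}   (strides of the 2x2x6 output)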
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
    // Find the location of the first element of the patch.
    Index patchIndex = index / m_outputStrides[output_stride_index];
    // Find the offset of the element wrt the location of the first element.
    Index patchOffset = index - patchIndex * m_outputStrides[output_stride_index];
    Index inputIndex = 0;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 2; i > 0; --i) {
        const Index patchIdx = patchIndex / m_patchStrides[i];
        patchIndex -= patchIdx * m_patchStrides[i];
        const Index offsetIdx = patchOffset / m_outputStrides[i];
        patchOffset -= offsetIdx * m_outputStrides[i];
        inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
      }
    } else {
      for (int i = 0; i < NumDims - 2; ++i) {
        const Index patchIdx = patchIndex / m_patchStrides[i];
        patchIndex -= patchIdx * m_patchStrides[i];
        const Index offsetIdx = patchOffset / m_outputStrides[i+1];
        patchOffset -= offsetIdx * m_outputStrides[i+1];
        inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
      }
    }
    inputIndex += (patchIndex + patchOffset);
    return m_impl.coeff(inputIndex);
  }
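  // Index decomposition example (an illustrative addition, continuing the
  // 3x4 input / 2x2 patch setup above): coeff(13) splits the output index
  // into patchIndex = 13 / 4 = 3 and patchOffset = 13 - 3*4 = 1. The loop
  // (i == 1) peels off patchIdx = 3 / 2 = 1 and offsetIdx = 1 / 2 = 0,
  // adding (1 + 0) * m_inputStrides[1] = 3; the remainders add 1 + 1 = 2,
  // giving inputIndex == 5: element (1, 0) of the patch whose origin is
  // input element (1, 1).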
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
    Index indices[2] = {index, index + PacketSize - 1};
    Index patchIndices[2] = {indices[0] / m_outputStrides[output_stride_index],
                             indices[1] / m_outputStrides[output_stride_index]};
    Index patchOffsets[2] = {indices[0] - patchIndices[0] * m_outputStrides[output_stride_index],
                             indices[1] - patchIndices[1] * m_outputStrides[output_stride_index]};

    Index inputIndices[2] = {0, 0};
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 2; i > 0; --i) {
        const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
                                   patchIndices[1] / m_patchStrides[i]};
        patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
        patchIndices[1] -= patchIdx[1] * m_patchStrides[i];

        const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i],
                                    patchOffsets[1] / m_outputStrides[i]};
        patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i];
        patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i];

        inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
        inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
      }
    } else {
      for (int i = 0; i < NumDims - 2; ++i) {
        const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
                                   patchIndices[1] / m_patchStrides[i]};
        patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
        patchIndices[1] -= patchIdx[1] * m_patchStrides[i];

        const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i+1],
                                    patchOffsets[1] / m_outputStrides[i+1]};
        patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i+1];
        patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i+1];

        inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
        inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
      }
    }
    inputIndices[0] += (patchIndices[0] + patchOffsets[0]);
    inputIndices[1] += (patchIndices[1] + patchOffsets[1]);

    if (inputIndices[1] - inputIndices[0] == PacketSize - 1) {
      // The packet spans contiguous coefficients in the input: load directly.
      PacketReturnType rslt = m_impl.template packet<Unaligned>(inputIndices[0]);
      return rslt;
    }
    else {
      // Non-contiguous in the input: gather the coefficients one by one.
      EIGEN_ALIGN_MAX CoeffReturnType values[PacketSize];
      values[0] = m_impl.coeff(inputIndices[0]);
      values[PacketSize-1] = m_impl.coeff(inputIndices[1]);
      for (int i = 1; i < PacketSize-1; ++i) {
        values[i] = coeff(index+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }
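  // Vectorization note (an illustrative addition, continuing the 3x4 input /
  // 2x2 patch setup): a packet of 4 floats covering one whole 2x2 patch maps
  // its first and last coefficients to input indices s and s + 4 (the two
  // patch columns sit 3 apart in the input), so the contiguity test fails
  // and the packet is gathered element by element. A packet whose
  // coefficients are consecutive in the input, e.g. the first column of a
  // 4x2 patch of a taller input, takes the single-load fast path instead.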
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    const double compute_cost = NumDims * (TensorOpCost::DivCost<Index>() +
                                           TensorOpCost::MulCost<Index>() +
                                           2 * TensorOpCost::AddCost<Index>());
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }
  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }

#ifdef EIGEN_USE_SYCL
  // binds placeholder accessors to a command group handler for SYCL
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_impl.bind(cgh);
  }
#endif

 protected:
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims-1> m_inputStrides;
  array<Index, NumDims-1> m_patchStrides;

  TensorEvaluator<ArgType, Device> m_impl;
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_PATCH_H