Eigen-unsupported  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
TensorDevice.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
11#define EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
29template <typename ExpressionType, typename DeviceType>
30class TensorDevice {
31 public:
32 TensorDevice(const DeviceType& device, ExpressionType& expression) : m_device(device), m_expression(expression) {}
33
34 EIGEN_DEFAULT_COPY_CONSTRUCTOR(TensorDevice)
35
36 template <typename OtherDerived>
37 EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) {
39 Assign assign(m_expression, other);
40 internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
41 return *this;
42 }
43
44 template <typename OtherDerived>
45 EIGEN_STRONG_INLINE TensorDevice& operator+=(const OtherDerived& other) {
46 typedef typename OtherDerived::Scalar Scalar;
47 typedef TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const ExpressionType, const OtherDerived> Sum;
48 Sum sum(m_expression, other);
50 Assign assign(m_expression, sum);
51 internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
52 return *this;
53 }
54
55 template <typename OtherDerived>
56 EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
57 typedef typename OtherDerived::Scalar Scalar;
58 typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived>
59 Difference;
60 Difference difference(m_expression, other);
62 Assign assign(m_expression, difference);
63 internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
64 return *this;
65 }
66
67 protected:
68 const DeviceType& m_device;
69 ExpressionType& m_expression;
70};
71
85
86template <typename ExpressionType, typename DeviceType, typename DoneCallback>
87class TensorAsyncDevice {
88 public:
89 TensorAsyncDevice(const DeviceType& device, ExpressionType& expression, DoneCallback done)
90 : m_device(device), m_expression(expression), m_done(std::move(done)) {}
91
92 template <typename OtherDerived>
93 EIGEN_STRONG_INLINE TensorAsyncDevice& operator=(const OtherDerived& other) {
96
97 Assign assign(m_expression, other);
98 Executor::run(assign, m_device);
99 m_done();
100
101 return *this;
102 }
103
104 protected:
105 const DeviceType& m_device;
106 ExpressionType& m_expression;
107 DoneCallback m_done;
108};
109
#ifdef EIGEN_USE_THREADS
/** Thread-pool specialization of TensorAsyncDevice.
 *
 * Unlike the generic version, operator= does not call the callback itself.
 * It hands the assignment, the device and the callback (by move) to
 * internal::TensorAsyncExecutor::runAsync. Presumably the executor invokes the
 * callback when evaluation finishes; confirm against TensorAsyncExecutor.
 *
 * The device and the destination expression are held by reference and must
 * outlive this object.
 */
template <typename ExpressionType, typename DoneCallback>
class TensorAsyncDevice<ExpressionType, ThreadPoolDevice, DoneCallback> {
 public:
  /** Binds \p device and \p expression by reference; takes ownership of \p done. */
  TensorAsyncDevice(const ThreadPoolDevice& device, ExpressionType& expression, DoneCallback done)
      : m_device(device),
        m_expression(expression),
        m_done(std::move(done)) {}

  /** Launches asynchronous evaluation of \p other into the wrapped expression.
   * \returns *this.
   */
  template <typename OtherDerived>
  EIGEN_STRONG_INLINE TensorAsyncDevice& operator=(const OtherDerived& other) {
    using Assign = TensorAssignOp<ExpressionType, const OtherDerived>;
    using Executor = internal::TensorAsyncExecutor<const Assign, ThreadPoolDevice, DoneCallback>;

    Assign assign_op(m_expression, other);
    // WARNING: m_done is moved into the executor here. After this call it is in
    // a moved-from (unspecified) state, so a second assignment would forward an
    // unusable callback.
    Executor::runAsync(assign_op, m_device, std::move(m_done));

    return *this;
  }

 protected:
  const ThreadPoolDevice& m_device;  // Non-owning; must outlive *this.
  ExpressionType& m_expression;      // Destination of the assignment; non-owning.
  DoneCallback m_done;               // Moved out by the first operator= call.
};
#endif
135
136} // end namespace Eigen
137
138#endif // EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
Definition TensorAssign.h:55
Pseudo expression providing an operator = that will evaluate its argument asynchronously on the specified device.
Definition TensorDevice.h:87
Tensor binary expression.
Definition TensorExpr.h:171
The tensor executor class.
Definition TensorExecutor.h:76
Namespace containing all symbols from the Eigen library.