libwt  1.0.0
A C++ library for the continuous wavelet transform.
convolution.hh
1 #ifndef __WT_CONVOLUTION_HH__
2 #define __WT_CONVOLUTION_HH__
3 
4 #include "types.hh"
5 #include "fft.hh"
6 
7 namespace wt {
8 
20 {
21 public:
24  Convolution(const CMatrix &kernels, size_t subSample = 1);
26  virtual ~Convolution();
27 
31  template <class iDerived, class oDerived>
32  void apply(const Eigen::DenseBase<iDerived> &signal, Eigen::DenseBase<oDerived> &out)
33  {
34  // First, clear _lastRes matrix
35  this->_lastRes.setConstant(0);
36 
37  /*
38  * Perform overlap-add FFT
39  */
40  size_t N = signal.size();
41  size_t steps = N/this->_M;
42  size_t rem = N%this->_M;
43  size_t out_offset = 0;
44 
45  // Compute the first complete steps
46  for (size_t i=0; i<steps; i++) {
47  // Store piece into forward-trafo buffer
48  this->_part.head(this->_M).noalias() =
49  signal.block(i*this->_M,0, _M,1).template cast<CScalar>();
50  // 0-pad
51  this->_part.tail(this->_M).setConstant(0);
52 
53  // perform forward FFT
54  this->_fwd.exec();
55 
56  // Multiply result of forward FFT of the signal piece with every (transformed) kernel
57  for (size_t j=0; j<this->_K; j++) {
58  this->_work.col(j).noalias() =
59  this->_part.cwiseProduct(this->_kernelF.col(j));
60  }
61 
62  // Peform backward trafo
63  this->_rev.exec();
64 
65  /*
66  * Compute result of convolution and store it into the output buffer
67  */
68  if (0 == out_offset) { // first block
69  out.topRows(this->_M/2).noalias() =
70  ( ( this->_work.block(this->_M/2, 0, this->_M/2, this->_K)
71  + this->_lastRes.topRows(this->_M/2) )/(2*this->_M) );
72  out_offset += ( this->_M/2 );
73  } else { // intermediate blocks
74  out.block(out_offset, 0, this->_M, this->_K).noalias() =
75  ( ( this->_work.topRows(this->_M) + this->_lastRes)/(2*this->_M) );
76  out_offset += this->_M;
77  }
78  // Store remaining part for next step unless we are at the last step
79  this->_lastRes.noalias() = this->_work.bottomRows(this->_M);
80  }
81 
82  // Perform final step (if rem>0)
83  if (rem) {
84  // Store remaining samples
85  this->_part.head(rem).noalias() = signal.block(steps*this->_M,0, rem,1);
86  // 0-pad
87  this->_part.tail(2*this->_M-rem).setConstant(0);
88 
89  // perform forward FFT
90  this->_fwd.exec();
91 
92  // Multiply result of forward FFT of the signal piece with every (transformed) kernel
93  for (size_t j=0; j<this->_K; j++) {
94  this->_work.col(j).noalias() =
95  this->_part.cwiseProduct(this->_kernelF.col(j));
96  }
97 
98  // Peform backward trafo
99  this->_rev.exec();
100 
101  // store result if _M/2+rem < _M
102  if (this->_M >= (rem+this->_M/2)) {
103  out.block(out_offset, 0, rem+this->_M/2, this->_K).noalias() =
104  ( ( this->_work.topRows(rem+this->_M/2) +
105  this->_lastRes.topRows(rem+this->_M/2) ) / (2*this->_M) );
106  } else {
107  out.block(out_offset, 0, this->_M, this->_K).noalias() =
108  ( ( this->_work.topRows(this->_M) +
109  this->_lastRes.topRows(this->_M) ) / (2*this->_M) );
110  out_offset += this->_M;
111  size_t n = rem+this->_M/2-this->_M;
112  out.block(out_offset, 0, n, this->_K).noalias() =
113  (this->_work.block(this->_M, 0, n, this->_K) / (2*this->_M));
114  }
115  } else {
116  // store last _M/2 samples if rem==0
117  out.block(out_offset, 0, this->_M/2, this->_K).noalias() =
118  ( this->_work.block(this->_M-rem, 0, this->_M/2, this->_K) / (2*this->_M) );
119  }
120  }
121 
123  inline size_t kernelLength() const { return this->_M; }
125  inline size_t numKernels() const { return this->_K; }
126 
128  inline size_t subSampling() const { return _subSampling; }
130  void setSubSampling(size_t subSample) { _subSampling = subSample; }
131 
132 protected:
134  size_t _K;
136  size_t _M;
138  CMatrix _kernelF;
139 
141  CVector _part;
144 
146  CMatrix _lastRes;
148  CMatrix _work;
151 
156  size_t _subSampling;
157 };
158 
159 }
160 
161 #endif // __WT_CONVOLUTION_HH__
void exec()
Executes the FFT.
Definition: fft_fftw3.cc:64
FFT _rev
Backward transformation.
Definition: convolution.hh:150
Convolution(const CMatrix &kernels, size_t subSample=1)
Constructor.
Definition: convolution.cc:6
CMatrix _lastRes
Second halves of the back-transformed, filtered signals.
Definition: convolution.hh:146
Implements the generic FFT-plan interface for the FFTW3 library.
Definition: fft_fftw3.hh:10
CVector _part
Working vector for the forward transform of a piece of the input signal.
Definition: convolution.hh:141
size_t kernelLength() const
Returns the length of the kernels.
Definition: convolution.hh:123
virtual ~Convolution()
Destructor.
Definition: convolution.cc:19
CMatrix _kernelF
Holds the Fourier transforms of the kernels.
Definition: convolution.hh:138
size_t numKernels() const
Returns the number of kernels.
Definition: convolution.hh:125
size_t _K
The number of kernels.
Definition: convolution.hh:134
CMatrix _work
Working memory for backward transformation of filtered signals.
Definition: convolution.hh:148
Definition: convolution.hh:7
size_t _M
The length of the kernels.
Definition: convolution.hh:136
FFT _fwd
The in-place FFT transform of a (zero-padded) signal part.
Definition: convolution.hh:143
Implements the overlap-add convolution of a signal with several filter kernels of the same size...
Definition: convolution.hh:19
void setSubSampling(size_t subSample)
Sets the sub-sampling assigned to the convolution operation.
Definition: convolution.hh:130
size_t subSampling() const
Returns the sub-sampling assigned to the convolution operation.
Definition: convolution.hh:128
void apply(const Eigen::DenseBase< iDerived > &signal, Eigen::DenseBase< oDerived > &out)
Performs the convolution of the signal passed as signal with the kernels passed to the constructor...
Definition: convolution.hh:32
size_t _subSampling
Possible subsampling for the kernels.
Definition: convolution.hh:156