Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)  0.17
Performance library for Deep Learning
mkldnn.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef MKLDNN_HPP
18 #define MKLDNN_HPP
19 
20 #ifndef DOXYGEN_SHOULD_SKIP_THIS
21 #include <stdlib.h>
22 #include <memory>
23 #include <vector>
24 #include <algorithm>
25 #include <iterator>
26 #include <string>
27 
28 #include "mkldnn.h"
29 #endif
30 
31 namespace mkldnn {
32 
35 
38 
/// Primary traits template used by `handle`. It is intentionally empty; each
/// wrapped C handle type gets an explicit specialization that provides a
/// static `destructor` callable for that type.
template <typename T> class handle_traits {};

/// Shared-ownership wrapper around an opaque C handle of pointer type @p T.
///
/// Copies of a handle share one std::shared_ptr, so the wrapped object is
/// passed to `traits::destructor` only once, when the last owning copy is
/// released. A handle created with `weak == true` installs a no-op deleter
/// and therefore never destroys the object it wraps.
template <typename T, typename traits=handle_traits<T>> class handle {
private:
    // The type the C handle points to; the shared_ptr manages this pointee.
    using pointee_type = typename std::remove_pointer<T>::type;

    std::shared_ptr<pointee_type> _data;

    // An rvalue `handle` binds to these deleted overloads in preference to
    // the copy operations, so handles of this exact type cannot be
    // constructed or assigned from rvalues.
    handle(const handle &&) = delete;
    handle &operator=(const handle &&other) = delete;

protected:
    /// Identity comparison against a raw C handle.
    bool operator==(const T other) const { return _data.get() == other; }
    bool operator!=(const T other) const { return !operator==(other); }

public:
    /// Wraps @p t (null by default). When @p weak is true the wrapper never
    /// calls `traits::destructor` on @p t.
    handle(T t = 0, bool weak = false): _data() {
        reset(t, weak);
    }

    /// Shares ownership of the wrapped object with @p other.
    handle(const handle &other): _data(other._data) {}

    /// Releases this wrapper's reference and shares @p other's object.
    handle &operator=(const handle &other) {
        _data = other._data;
        return *this;
    }

    /// Replaces the wrapped C handle with @p t, dropping this wrapper's
    /// reference to the previous one. @p weak selects a no-op deleter.
    void reset(T t, bool weak = false) {
        using status_type = decltype(traits::destructor(0));
        // Deleter for weak handles: accepts the pointer and does nothing.
        auto noop_destructor = [](T) { return status_type(0); };
        auto deleter = weak ? noop_destructor : traits::destructor;
        _data.reset(t, deleter);
    }

    /// Returns the raw C handle (null for an empty wrapper).
    T get() const { return _data.get(); }

    /// Two wrappers are equal when they refer to the same C handle.
    bool operator==(const handle &other) const {
        return _data.get() == other._data.get();
    }
    bool operator!=(const handle &other) const { return !operator==(other); }
};
90 
91 #ifndef DOXYGEN_SHOULD_SKIP_THIS
92 template <> struct handle_traits<mkldnn_primitive_desc_t> {
93  static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
94 };
95 
96 template <> struct handle_traits<mkldnn_primitive_t> {
97  static constexpr auto destructor = &mkldnn_primitive_destroy;
98 };
99 
100 template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> {
101  static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy;
102 };
103 #endif
104 
106 class primitive: public handle<mkldnn_primitive_t> {
107  friend struct error;
108  friend struct stream;
109  friend class primitive_at;
110  using handle::handle;
111 public:
113  enum class kind {
114  undefined_primitive = mkldnn_undefined_primitive,
116  view = mkldnn_view,
119  concat_inplace = mkldnn_concat_inplace,
120  sum = mkldnn_sum,
121  convolution = mkldnn_convolution,
122  deconvolution = mkldnn_deconvolution,
123  shuffle = mkldnn_shuffle,
124  eltwise = mkldnn_eltwise,
125  softmax = mkldnn_softmax,
126  pooling = mkldnn_pooling,
127  lrn = mkldnn_lrn,
128  batch_normalization = mkldnn_batch_normalization,
129  inner_product = mkldnn_inner_product,
130  rnn = mkldnn_rnn,
131  };
132 
134  struct at {
142 
143  at(const primitive &aprimitive, size_t at = 0)
144  : data(mkldnn_primitive_at(aprimitive.get(), at)) {}
146  inline operator primitive() const;
147  };
148 
150  inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
151  // TODO: use the C++ API wrapper structure.
152 };
153 
155  return static_cast<mkldnn_primitive_kind_t>(akind);
156 }
161 struct error: public std::exception {
163  std::string message;
165 
172 
173  error(mkldnn_status_t astatus, std::string amessage,
174  mkldnn_primitive_t aerror_primitive = 0)
175  : status(astatus)
176  , message(amessage)
177  , error_primitive(aerror_primitive, true)
178  {}
179 
187 
188  static void wrap_c_api(mkldnn_status_t status,
189  const std::string &message,
190  mkldnn_primitive_t *error_primitive = 0)
191  {
192  if (status != mkldnn_success) {
193  if (nullptr != error_primitive)
194  throw error(status, message, *error_primitive);
195  else
196  throw error(status, message, nullptr);
197  }
198  }
199 };
200 
201 inline primitive::at::operator primitive() const {
204  mkldnn_primitive_get_output(data.primitive,
205  data.output_index, &output),
206  "could not get an output primitive");
207  return primitive(const_cast<mkldnn_primitive_t>(output), true);
208 }
209 
213  "could not get primitive descriptor by primitive");
214  return pd;
215 }
217 
222 
226 };
227 
229  return static_cast<mkldnn_round_mode_t>(mode);
230 }
231 
234 };
235 
237  return static_cast<mkldnn_padding_kind_t>(kind);
238 }
239 
240 enum prop_kind {
249 };
250 
252  return static_cast<mkldnn_prop_kind_t>(kind);
253 }
254 
255 enum algorithm {
282 };
283 
285  return static_cast<mkldnn_alg_kind_t>(aalgorithm);
286 }
287 
292 };
293 
295  batch_normalization_flag aflag) {
296  return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
297 }
298 
305 };
306 
308  return static_cast<mkldnn_rnn_direction_t>(adir);
309 }
310 
311 enum query {
313 
316 
319 
322 
324 
337 
347 };
348 
350  return static_cast<mkldnn_query_t>(aquery);
351 }
352 
354 
360 
361 #ifndef DOXYGEN_SHOULD_SKIP_THIS
362 template <> struct handle_traits<mkldnn_post_ops_t> {
363  static constexpr auto destructor = &mkldnn_post_ops_destroy;
364 };
365 #endif
366 
367 struct post_ops: public handle<mkldnn_post_ops_t> {
369  mkldnn_post_ops_t result;
371  "could not create post operation sequence");
372  reset(result);
373  }
374 
375  int len() const { return mkldnn_post_ops_len(get()); }
376 
377  primitive::kind kind(int index) const {
379  index < len() ? mkldnn_success : mkldnn_invalid_arguments,
380  "post_ops index is out of range");
381  return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
382  index));
383  }
384 
385  void append_sum(float scale = 1.) {
387  "could not append sum");
388  }
389 
390  void get_params_sum(int index, float &scale) const {
392  "could not get sum params");
393  }
394 
395  void append_eltwise(float scale, algorithm alg, float alpha,
396  float beta) {
398  convert_to_c(alg), alpha, beta),
399  "could not append eltwise");
400  }
401 
402  void get_params_eltwise(int index, float &scale, algorithm &alg,
403  float &alpha, float &beta) const {
404  mkldnn_alg_kind_t c_alg;
406  &scale, &c_alg, &alpha, &beta),
407  "could not get eltwise params");
408  alg = static_cast<algorithm>(c_alg);
409  }
410 };
411 
412 #ifndef DOXYGEN_SHOULD_SKIP_THIS
413 template <> struct handle_traits<mkldnn_primitive_attr_t> {
414  static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
415 };
416 #endif
417 
418 struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
420  mkldnn_primitive_attr_t result;
422  "could not create a primitive attr");
423  reset(result);
424  }
425 
427  mkldnn_round_mode_t result;
429  get(), &result), "could not get int output round mode");
430  return round_mode(result);
431  }
432 
435  get(), mkldnn::convert_to_c(mode)),
436  "could not set int output round mode");
437  }
438 
439  void get_output_scales(int &mask, std::vector<float> &scales) const
440  {
441  int count, c_mask;
442  const float *c_scales;
444  &count, &c_mask, &c_scales),
445  "could not get int output scales");
446  scales.resize(count);
447 
448  mask = c_mask;
449  for (int c = 0; c < count; ++c)
450  scales[c] = c_scales[c];
451  }
452 
453  void set_output_scales(int mask, const std::vector<float> &scales)
454  {
456  (int)scales.size(), mask, &scales[0]),
457  "could not set int output scales");
458  }
459 
460  const post_ops get_post_ops() const {
461  post_ops result;
462  const_mkldnn_post_ops_t c_result;
464  "could not get post operation sequence");
465  result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
466  return result;
467  }
468 
469  void set_post_ops(post_ops ops) {
471  "could not set post operation sequence");
472  }
473 
474  void set_rnn_data_qparams(const float scale, const float shift)
475  {
477  scale, shift), "could not set rnn data int scale/shift");
478  }
479 
480  void set_rnn_weights_qparams(int mask, const std::vector<float> &scales)
481  {
483  (int)scales.size(), mask, &scales[0]),
484  "could not set rnn weights int scales");
485  }
486 };
487 
489 
495 
496 #ifndef DOXYGEN_SHOULD_SKIP_THIS
497 template <> struct handle_traits<mkldnn_engine_t> {
498  static constexpr auto destructor = &mkldnn_engine_destroy;
499 };
500 #endif
501 
503 struct engine: public handle<mkldnn_engine_t> {
504  friend class primitive;
505  // gcc bug??? using handle::handle;
506 
508  enum kind {
512  cpu = mkldnn_cpu,
513  };
514 
518 
519  static size_t get_count(kind akind) {
520  return mkldnn_engine_get_count(convert_to_c(akind));
521  }
522 
528 
529  engine(kind akind, size_t index) {
530  mkldnn_engine_t aengine;
532  mkldnn_engine_create(&aengine,
533  convert_to_c(akind), index),
534  "could not create an engine");
535  reset(aengine);
536  }
537 
538  explicit engine(const mkldnn_engine_t& aengine)
539  : handle(aengine, true) {}
540 
542  mkldnn_engine_t engine_q;
545  mkldnn::convert_to_c(eengine), 0, &engine_q),
546  "could not get engine from primitive_desc");
547  reset(engine_q, true);
548  }
549 
550  template <class primitive_desc>
551  static engine query(const primitive_desc &pd) {
552  mkldnn_engine_t engine_q;
555  mkldnn::convert_to_c(eengine), 0, &engine_q),
556  "could not get engine from primitive_desc");
557 
558  return engine(engine_q);
559  }
560 
561 private:
562  static mkldnn_engine_kind_t convert_to_c(kind akind) {
563  return static_cast<mkldnn_engine_kind_t>(akind);
564  }
565 };
566 
568 
571 
577 
579 struct memory: public primitive {
580  private:
581  std::shared_ptr<char> _handle;
582 
583  public:
584  typedef std::vector<std::remove_extent<mkldnn_dims_t>::type> dims;
585 
586  template <typename T> static void validate_dims(std::vector<T> v) {
587  if (v.size() > TENSOR_MAX_DIMS)
589  "invalid dimensions");
590  }
591 
594  enum data_type {
596  f32 = mkldnn_f32,
597  s32 = mkldnn_s32,
598  s16 = mkldnn_s16,
599  s8 = mkldnn_s8,
600  u8 = mkldnn_u8,
601  };
602 
// memory::format: memory layout tags. Every enumerator below is initialized
// from the identically named mkldnn_* constant of the C API, so a value can be
// handed to the C API with a plain static_cast to mkldnn_memory_format_t.
// NOTE(review): the tag letters appear to spell out the dimension order
// (n, c, d, h, w, o, i, g, ...) with digit+letter suffixes denoting blocked
// dimensions (e.g. nChw16c); confirm against the C API documentation.
605  enum format {
606  format_undef = mkldnn_format_undef,
607  any = mkldnn_any,
608  blocked = mkldnn_blocked,
609  x = mkldnn_x,
610  nc = mkldnn_nc,
611  ncw = mkldnn_ncw,
612  nwc = mkldnn_nwc,
613  nCw16c = mkldnn_nCw16c,
614  nchw = mkldnn_nchw,
615  nhwc = mkldnn_nhwc,
616  chwn = mkldnn_chwn,
617  nCw8c = mkldnn_nCw8c,
618  nChw8c = mkldnn_nChw8c,
619  nChw16c = mkldnn_nChw16c,
620  ncdhw = mkldnn_ncdhw,
621  ndhwc = mkldnn_ndhwc,
622  nCdhw8c = mkldnn_nCdhw8c,
623  nCdhw16c = mkldnn_nCdhw16c,
624  oi = mkldnn_oi,
625  io = mkldnn_io,
626  oiw = mkldnn_oiw,
627  wio = mkldnn_wio,
628  Owi8o = mkldnn_Owi8o,
629  OIw8o8i = mkldnn_OIw8o8i,
630  OIw8i8o = mkldnn_OIw8i8o,
631  OIw16i16o = mkldnn_OIw16i16o,
632  OIw16o16i = mkldnn_OIw16o16i,
633  Oiw16o = mkldnn_Oiw16o,
634  Owi16o = mkldnn_Owi16o,
635  OIw8i16o2i = mkldnn_OIw8i16o2i,
636  OIw8o16i2o = mkldnn_OIw8o16i2o,
637  IOw16o16i = mkldnn_IOw16o16i,
638  oihw = mkldnn_oihw,
639  ihwo = mkldnn_ihwo,
640  hwio = mkldnn_hwio,
641  iohw = mkldnn_iohw,
642  hwio_s8s8 = mkldnn_hwio_s8s8,
643  dhwio = mkldnn_dhwio,
644  oidhw = mkldnn_oidhw,
645  OIdhw8i8o = mkldnn_OIdhw8i8o,
646  OIdhw8o8i = mkldnn_OIdhw8o8i,
647  Odhwi8o = mkldnn_Odhwi8o,
648  OIdhw16i16o = mkldnn_OIdhw16i16o,
649  OIdhw16o16i = mkldnn_OIdhw16o16i,
650  Oidhw16o = mkldnn_Oidhw16o,
651  Odhwi16o = mkldnn_Odhwi16o,
652  oIhw8i = mkldnn_oIhw8i,
653  oIhw16i = mkldnn_oIhw16i,
654  oIdhw8i = mkldnn_oIdhw8i,
655  oIdhw16i = mkldnn_oIdhw16i,
656  OIhw8i8o = mkldnn_OIhw8i8o,
657  OIhw16i16o = mkldnn_OIhw16i16o,
658  OIhw8o8i = mkldnn_OIhw8o8i,
659  OIhw16o16i = mkldnn_OIhw16o16i,
660  IOhw16o16i = mkldnn_IOhw16o16i,
661  OIhw8i16o2i = mkldnn_OIhw8i16o2i,
662  OIdhw8i16o2i = mkldnn_OIdhw8i16o2i,
663  OIhw8o16i2o = mkldnn_OIhw8o16i2o,
664  OIhw4i16o4i = mkldnn_OIhw4i16o4i,
665  OIhw4i16o4i_s8s8 = mkldnn_OIhw4i16o4i_s8s8,
666  Oihw8o = mkldnn_Oihw8o,
667  Oihw16o = mkldnn_Oihw16o,
668  Ohwi8o = mkldnn_Ohwi8o,
669  Ohwi16o = mkldnn_Ohwi16o,
670  OhIw16o4i = mkldnn_OhIw16o4i,
671  goiw = mkldnn_goiw,
672  gOwi8o = mkldnn_gOwi8o,
673  gOIw8o8i = mkldnn_gOIw8o8i,
674  gOIw8i8o = mkldnn_gOIw8i8o,
675  gOIw16i16o = mkldnn_gOIw16i16o,
676  gOIw16o16i = mkldnn_gOIw16o16i,
677  gOiw16o = mkldnn_gOiw16o,
678  gOwi16o = mkldnn_gOwi16o,
679  gOIw8i16o2i = mkldnn_gOIw8i16o2i,
680  gIOw16o16i = mkldnn_gIOw16o16i,
681  gOIw8o16i2o = mkldnn_gOIw8o16i2o,
682  goihw = mkldnn_goihw,
683  hwigo = mkldnn_hwigo,
684  giohw = mkldnn_giohw,
685  hwigo_s8s8 = mkldnn_hwigo_s8s8,
686  gOIdhw8i8o = mkldnn_gOIdhw8i8o,
687  gOIdhw8o8i = mkldnn_gOIdhw8o8i,
688  gOdhwi8o = mkldnn_gOdhwi8o,
689  gOIhw8i8o = mkldnn_gOIhw8i8o,
690  gOIhw16i16o = mkldnn_gOIhw16i16o,
691  gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
692  gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i,
693  gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
694  gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
695  gOIhw4i16o4i_s8s8 = mkldnn_gOIhw4i16o4i_s8s8,
696  gOihw8o = mkldnn_gOihw8o,
697  gOihw16o = mkldnn_gOihw16o,
698  gOhwi8o = mkldnn_gOhwi8o,
699  gOhwi16o = mkldnn_gOhwi16o,
700  Goihw8g = mkldnn_Goihw8g,
701  Goihw16g = mkldnn_Goihw16g,
702  Goihw16g_s8s8 = mkldnn_Goihw16g_s8s8,
703  gOIhw8o8i = mkldnn_gOIhw8o8i,
704  gOIhw16o16i = mkldnn_gOIhw16o16i,
705  gIOhw16o16i = mkldnn_gIOhw16o16i,
706  gOhIw16o4i = mkldnn_gOhIw16o4i,
707  goidhw = mkldnn_goidhw,
708  gOIdhw16i16o = mkldnn_gOIdhw16i16o,
709  gOIdhw16o16i = mkldnn_gOIdhw16o16i,
710  gOidhw16o = mkldnn_gOidhw16o,
711  gOdhwi16o = mkldnn_gOdhwi16o,
712  ntc = mkldnn_ntc,
713  tnc = mkldnn_tnc,
714  ldsnc = mkldnn_ldsnc,
715  ldigo = mkldnn_ldigo,
716  ldgoi = mkldnn_ldgoi,
717  ldgo = mkldnn_ldgo,
718  rnn_packed = mkldnn_rnn_packed,
719  wino_fmt = mkldnn_wino_fmt,
// NOTE(review): presumably an end-of-range sentinel rather than a real
// layout; confirm before passing it as a format.
720  format_last = mkldnn_format_last,
721  };
722 
724  struct desc {
725  friend struct memory;
728 
734  desc(dims adims, data_type adata_type,
735  format aformat) {
736  validate_dims(adims);
738  mkldnn_memory_desc_init(&data, (int)adims.size(),
739  adims.size() == 0 ? nullptr : &adims[0],
740  convert_to_c(adata_type), convert_to_c(aformat)),
741  "could not initialize a memory descriptor");
742  }
743 
747  desc(const mkldnn_memory_desc_t &adata): data(adata) {}
748  };
749 
751  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
752  friend struct memory;
753 
754  // TODO: make private
756 
758  primitive_desc(const desc &adesc, const engine &aengine) {
759  mkldnn_primitive_desc_t result;
762  &adesc.data, aengine.get()),
763  "could not initialize a memory primitive descriptor");
764  reset(result);
765  }
766 
770  return memory::desc(*memory_d); }
771 
774  size_t get_size() const {
776  }
777 
778  bool operator==(const primitive_desc &other) const {
779  return (0 == mkldnn_memory_primitive_desc_equal(get(),
780  other.get())) ? false : true;
781  }
782 
783  bool operator!=(const primitive_desc &other) const {
784  return !operator==(other);
785  }
786 
787  engine get_engine() { return engine::query(*this); }
788  };
789 
793  memory(const primitive &aprimitive): primitive(aprimitive) {}
797  memory(const primitive_desc &adesc) {
798  mkldnn_primitive_t result;
800  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
801  "could not create a memory primitive");
802  reset(result);
803  auto _malloc = [](size_t size, int alignment) {
804  void *ptr;
805 #ifdef _WIN32
806  ptr = _aligned_malloc(size, alignment);
807  int rc = ((ptr)? 0 : errno);
808 #else
809  int rc = ::posix_memalign(&ptr, alignment, size);
810 #endif /* _WIN32 */
811  return (rc == 0) ? (char*)ptr : nullptr;
812  };
813  auto _free = [](char* p) {
814 #ifdef _WIN32
815  _aligned_free((void*)p);
816 #else
817  ::free((void*)p);
818 #endif /* _WIN32 */
819  };
820  _handle.reset(_malloc(adesc.get_size(), 4096), _free);
821  set_data_handle(_handle.get());
822  }
823 
824  memory(const primitive_desc &adesc, void *ahandle) {
825  mkldnn_primitive_t result;
827  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
828  "could not create a memory primitive");
829  reset(result);
830  set_data_handle(ahandle);
831  }
832 
835  primitive_desc adesc;
838  &cdesc),
839  "could not get primitive descriptor from a memory primitive");
840  /* FIXME: no const_cast should be here */
841  adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);
842  return adesc;
843  }
844 
847  inline void *get_data_handle() const {
848  void *handle;
850  "could not get native handle");
851  return handle;
852  }
853 
854  inline void set_data_handle(void *handle) const {
856  "could not set native handle");
857  }
858 
859  // Must go away or be private:
861  return static_cast<mkldnn_data_type_t>(adata_type);
862  }
864  return static_cast<mkldnn_memory_format_t>(aformat);
865  }
866 };
867 
869  auto zero = mkldnn_memory_desc_t();
870  zero.primitive_kind = mkldnn_memory;
871  return memory::desc(zero);
872 }
873 
874 inline memory null_memory(engine eng) {
876  return memory({zero, eng}, nullptr);
877 }
878 
880  &aprimitive_desc, int n_inputs, int n_outputs,
881  const std::string &prim_name) {
882  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
883  aprimitive_desc, mkldnn_query_num_of_inputs_s32, 0);
884  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
885  aprimitive_desc, mkldnn_query_num_of_outputs_s32, 0);
886  if (n_outputs_expected > n_outputs ) {
887  std::string message = "could not create " + prim_name +
888  " primitive, not enought output parameters";
889  throw error(mkldnn_invalid_arguments, message, nullptr);
890  }
891  if (n_inputs_expected > n_inputs ) {
892  std::string message = "could not create " + prim_name +
893  " primitive, not enought input parameters";
894  throw error(mkldnn_invalid_arguments, message, nullptr);
895  }
896 }
897 
898 
899 inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) {
900  const_mkldnn_primitive_desc_t aprimitive_pd;
901  mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd);
903  aprimitive_pd);
904 
905  return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0));
906 }
907 
909  return a == memory::convert_to_c(b);
910 }
912  return !(a == b);
913 }
915  return b == a;
916 }
918  return !(a == b);
919 }
920 
922  return a == memory::convert_to_c(b);
923 }
925  return !(a == b);
926 }
928  return b == a;
929 }
931  return !(a == b);
932 }
933 
935 
941 
942 struct reorder : public primitive {
943  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
945  const memory::primitive_desc &output) {
946  mkldnn_primitive_desc_t result;
948  &result, input.get(), output.get()),
949  "could not create a reorder primitive descriptor");
950  reset(result);
951  }
952 
954  const memory::primitive_desc &output,
955  const primitive_attr &aattr) {
956  mkldnn_primitive_desc_t result;
958  &result, input.get(), output.get(), aattr.get()),
959  "could not create a reorder primitive descriptor");
960  reset(result);
961  }
962 
963  engine get_engine() { return engine::query(*this); }
964  };
965 
966  reorder(const primitive_desc &aprimitive_desc,
967  const primitive::at &input, const memory &output) {
968  mkldnn_primitive_t result;
969  mkldnn_primitive_at_t inputs[] = { input.data };
970  const_mkldnn_primitive_t outputs[] = { output.get() };
972  aprimitive_desc.get(), inputs, outputs),
973  "could not create a reorder primitive");
974  reset(result);
975  }
976 
977  reorder(const primitive::at &input, const memory &output) {
978  auto input_mpd = memory(input).get_primitive_desc();
979  auto output_mpd = output.get_primitive_desc();
980 
981  auto reorder_d = primitive_desc(input_mpd, output_mpd);
982 
983  mkldnn_primitive_t result;
984  mkldnn_primitive_at_t inputs[] = { input.data };
985  const_mkldnn_primitive_t outputs[] = { output.get() };
987  reorder_d.get(), inputs, outputs),
988  "could not create a reorder primitive");
989  reset(result);
990  }
991 };
992 
994 
1000 
1001 struct view : public primitive {
1002  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1004  memory::dims offsets) {
1005  mkldnn_primitive_desc_t result;
1006 
1008  &result, input.get(), &dims[0], &offsets[0]),
1009  "could not create a view primitive descriptor");
1010  reset(result);
1011  }
1012 
1014  memory::primitive_desc adesc;
1015  mkldnn_primitive_desc_t cdesc;
1016  const_mkldnn_primitive_desc_t const_cdesc =
1020  const_cdesc),
1021  "could not clone a dst primitive descriptor");
1022  adesc.reset(cdesc);
1023  return adesc;
1024  }
1025 
1026  engine get_engine() { return engine::query(*this); }
1027  };
1028 
1029  view(const primitive_desc &view_pd, primitive::at input) {
1030  mkldnn_primitive_t result;
1031  mkldnn_primitive_at_t inputs[] = { input.data };
1033  view_pd.get(), inputs, nullptr),
1034  "could not create a view primitive");
1035  reset(result);
1036  }
1037 
1038  view(memory input, memory::dims dims, memory::dims offsets) {
1039  mkldnn_primitive_t result;
1040  primitive_desc view_pd(input.get_primitive_desc(), dims,
1041  offsets);
1042  mkldnn_primitive_at_t inputs[] = { primitive::at(input).data };
1044  view_pd.get(), inputs, nullptr),
1045  "could not create a view primitive");
1046  reset(result);
1047  }
1048 };
1049 
1051 
1057 
1058 struct concat : public primitive {
1059  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1060  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1061  std::vector<memory::primitive_desc> inputs) {
1062  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1063  c_api_inputs.reserve(inputs.size());
1064  auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
1065  std::transform(inputs.begin(), inputs.end(),
1066  std::back_inserter(c_api_inputs), convert_to_c);
1067  return c_api_inputs;
1068  }
1069 
1070  primitive_desc(const memory::desc &output, int concat_dimension,
1071  std::vector<memory::primitive_desc> inputs) {
1072  mkldnn_primitive_desc_t result;
1073 
1074  auto c_api_inputs = cpp_to_c(inputs);
1075 
1077  &result, &output.data, (int)c_api_inputs.size(),
1078  concat_dimension, &c_api_inputs[0]),
1079  "could not create a concat primitive descriptor");
1080  reset(result);
1081  }
1082 
1083  primitive_desc(int concat_dimension,
1084  std::vector<memory::primitive_desc> inputs) {
1085  mkldnn_primitive_desc_t result;
1086 
1087  auto c_api_inputs = cpp_to_c(inputs);
1088 
1090  &result, nullptr, (int)c_api_inputs.size(),
1091  concat_dimension, &c_api_inputs[0]),
1092  "could not create a concat primitive descriptor");
1093  reset(result);
1094  }
1095 
1097  memory::primitive_desc adesc;
1098  mkldnn_primitive_desc_t cdesc;
1099  const_mkldnn_primitive_desc_t const_cdesc =
1102  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1103  "could not clone a dst primitive descriptor");
1104  adesc.reset(cdesc);
1105  return adesc;
1106  }
1107 
1108  engine get_engine() { return engine::query(*this); }
1109  };
1110 
1111  concat(const primitive_desc &concat_pd,
1112  std::vector<primitive::at> &inputs, const memory &output) {
1113  mkldnn_primitive_t result;
1114 
1115  std::vector<mkldnn_primitive_at_t> p_inputs;
1116  for (size_t i = 0; i < inputs.size(); i++)
1117  p_inputs.push_back(inputs[i].data);
1118  const_mkldnn_primitive_t outputs[] = { output.get() };
1119 
1121  concat_pd.get(), &p_inputs[0], outputs),
1122  "could not create a concat primitive");
1123  reset(result);
1124  }
1125 };
1126 
1128 
1134 
1135 struct sum : public primitive {
1136  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1137  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1138  std::vector<memory::primitive_desc> inputs) {
1139  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1140  c_api_inputs.reserve(inputs.size());
1141  auto convert_to_c = [](memory::primitive_desc d) { return d.get();};
1142  std::transform(inputs.begin(), inputs.end(),
1143  std::back_inserter(c_api_inputs), convert_to_c);
1144  return c_api_inputs;
1145  }
1146 
1148  const std::vector<float> &scales,
1149  std::vector<memory::primitive_desc> inputs) {
1150  mkldnn_primitive_desc_t result;
1151 
1152  auto c_api_inputs = cpp_to_c(inputs);
1153 
1155  scales.size() == inputs.size() ? mkldnn_success
1157  "number of scales not equal to number of inputs");
1158 
1160  &result, &output.data, (int)c_api_inputs.size(),
1161  &scales[0], &c_api_inputs[0]),
1162  "could not create a sum primitive descriptor");
1163  reset(result);
1164  }
1165 
1166  primitive_desc(const std::vector<float> &scales,
1167  std::vector<memory::primitive_desc> inputs) {
1168  mkldnn_primitive_desc_t result;
1169 
1170  auto c_api_inputs = cpp_to_c(inputs);
1171 
1173  scales.size() == inputs.size() ? mkldnn_success
1175  "number of scales not equal to number of inputs");
1176 
1178  &result, nullptr, (int)c_api_inputs.size(), &scales[0],
1179  &c_api_inputs[0]),
1180  "could not create a sum primitive descriptor");
1181  reset(result);
1182  }
1183 
1185  memory::primitive_desc adesc;
1186  mkldnn_primitive_desc_t cdesc;
1187  const_mkldnn_primitive_desc_t const_cdesc =
1191  const_cdesc),
1192  "could not clone a dst primitive descriptor");
1193  adesc.reset(cdesc);
1194  return adesc;
1195  }
1196 
1197  engine get_engine() { return engine::query(*this); }
1198  };
1199 
1200  sum(const primitive_desc &sum_pd,
1201  std::vector<primitive::at> &inputs, const memory &output) {
1202  mkldnn_primitive_t result;
1203 
1204  std::vector<mkldnn_primitive_at_t> p_inputs;
1205  for (size_t i = 0; i < inputs.size(); i++)
1206  p_inputs.push_back(inputs[i].data);
1207  const_mkldnn_primitive_t outputs[] = { output.get() };
1208 
1210  sum_pd.get(), &p_inputs[0], outputs),
1211  "could not create a sum primitive");
1212  reset(result);
1213  }
1214 };
1215 
1217 
1219 
1222 
1225 
1227 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1229  const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
1230  mkldnn_primitive_desc_iterator_t iterator = nullptr;
1232  &iterator, desc, attr ? attr->get() : nullptr, e.get(),
1233  hint_fwd_pd);
1234  error::wrap_c_api(status,
1235  "could not create a primitive descriptor iterator");
1236  pd_iterator.reset(iterator);
1237  fetch_impl();
1238  }
1239 
1240  engine get_engine() { return engine::query(*this); }
1241 
1243  const_mkldnn_primitive_attr_t const_cattr;
1245  "could not get attributes");
1246  mkldnn_primitive_attr_t cattr;
1247  error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr),
1248  "could not clone attributes");
1249 
1250  primitive_attr attr;
1251  attr.reset(cattr);
1252  return attr;
1253  }
1254 
1256  const char *impl_info_str() const {
1257  const char *res;
1259  mkldnn_query_impl_info_str, 0, &res),
1260  "could not query implementation info string");
1261  return res;
1262  }
1263 
1270  bool next_impl() {
1272  pd_iterator.get());
1273  if (status == mkldnn_iterator_ends) return false;
1274  error::wrap_c_api(status, "primitive descriptor iterator next failed");
1275 
1276  fetch_impl();
1277  return true;
1278  }
1279 
1281  memory::primitive_desc query_mpd(query what, int idx = 0) const {
1282  std::vector<query> valid_w{input_pd, output_pd, src_pd, diff_src_pd,
1284  if (!std::any_of(valid_w.cbegin(), valid_w.cend(),
1285  [=](query q) { return what == q; }))
1286  throw error(mkldnn_invalid_arguments, "invalid memory query");
1287 
1288  const_mkldnn_primitive_desc_t const_cdesc
1290  mkldnn::convert_to_c(what), idx);
1291 
1292  // TODO: is there a better way to inform about this?
1293  if (const_cdesc == nullptr)
1294  throw error(mkldnn_not_required, "queried memory is not required");
1295 
1296  mkldnn_primitive_desc_t cdesc;
1297  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1298  "could not clone a memory primitive descriptor");
1299 
1301  ret.reset(cdesc);
1302  return ret;
1303  }
1304 
// Defines a `<name>_primitive_desc()` accessor that forwards to
// query_mpd(<what>_pd, <idx>); e.g. REG_QUERY_MPD(src, src, 0) defines
// src_primitive_desc(). NOTE(review): the macro is not #undef'd in the
// visible part of this header, so it may leak to includers; confirm.
# define REG_QUERY_MPD(name, what, idx) \
    memory::primitive_desc name##_primitive_desc() const \
    { return this->query_mpd(what##_pd, idx); }
1309 
1310  private:
1311  handle<mkldnn_primitive_desc_iterator_t> pd_iterator;
1312  void fetch_impl() {
1313  mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(
1314  pd_iterator.get());
1316  "could not fetch a primitive descriptor from the iterator");
1317  reset(pd);
1318  }
1319 };
1320 
1322 
1328 
1330  struct desc {
1332  desc(prop_kind aprop_kind, algorithm aalgorithm,
1333  const memory::desc &src_desc,
1334  const memory::desc &weights_desc,
1335  const memory::desc &bias_desc,
1336  const memory::desc &dst_desc,
1337  const memory::dims strides,
1338  const memory::dims padding_l,
1339  const memory::dims padding_r,
1340  const padding_kind apadding_kind) {
1341  memory::validate_dims(strides);
1342  memory::validate_dims(padding_l);
1343  memory::validate_dims(padding_r);
1345  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1346  &src_desc.data, &weights_desc.data, &bias_desc.data,
1347  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1348  mkldnn::convert_to_c(apadding_kind)),
1349  "could not create a convolution forward descriptor");
1350  }
1351  desc(prop_kind aprop_kind, algorithm aalgorithm,
1352  const memory::desc &src_desc,
1353  const memory::desc &weights_desc,
1354  const memory::desc &dst_desc,
1355  const memory::dims strides,
1356  const memory::dims padding_l,
1357  const memory::dims padding_r,
1358  const padding_kind apadding_kind) {
1359  memory::validate_dims(strides);
1360  memory::validate_dims(padding_l);
1361  memory::validate_dims(padding_r);
1363  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1364  &src_desc.data, &weights_desc.data, nullptr,
1365  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1366  mkldnn::convert_to_c(apadding_kind)),
1367  "could not create a convolution forward descriptor");
1368  }
1369  desc(prop_kind aprop_kind, algorithm aalgorithm,
1370  const memory::desc &src_desc,
1371  const memory::desc &weights_desc,
1372  const memory::desc &bias_desc,
1373  const memory::desc &dst_desc,
1374  const memory::dims strides,
1375  const memory::dims dilates,
1376  const memory::dims padding_l,
1377  const memory::dims padding_r,
1378  const padding_kind apadding_kind) {
1379  memory::validate_dims(strides);
1380  memory::validate_dims(dilates);
1381  memory::validate_dims(padding_l);
1382  memory::validate_dims(padding_r);
1385  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1386  &src_desc.data, &weights_desc.data, &bias_desc.data,
1387  &dst_desc.data, &strides[0], &dilates[0],
1388  &padding_l[0], &padding_r[0],
1389  mkldnn::convert_to_c(apadding_kind)),
1390  "could not create a dilated convolution forward descriptor");
1391  }
1392  desc(prop_kind aprop_kind, algorithm aalgorithm,
1393  const memory::desc &src_desc,
1394  const memory::desc &weights_desc,
1395  const memory::desc &dst_desc,
1396  const memory::dims strides,
1397  const memory::dims dilates,
1398  const memory::dims padding_l,
1399  const memory::dims padding_r,
1400  const padding_kind apadding_kind) {
1401  memory::validate_dims(strides);
1402  memory::validate_dims(dilates);
1403  memory::validate_dims(padding_l);
1404  memory::validate_dims(padding_r);
1407  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1408  &src_desc.data, &weights_desc.data, nullptr,
1409  &dst_desc.data, &strides[0], &dilates[0],
1410  &padding_l[0], &padding_r[0],
1411  mkldnn::convert_to_c(apadding_kind)),
1412  "could not create a dilated convolution forward descriptor");
1413  }
1414  };
1415 
1417  primitive_desc(const desc &desc, const engine &e)
1418  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1419 
1420  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1421  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1422 
1423  REG_QUERY_MPD(src, src, 0);
1424  REG_QUERY_MPD(weights, weights, 0);
1425  REG_QUERY_MPD(bias, weights, 1);
1426  REG_QUERY_MPD(dst, dst, 0);
1427  };
1428 
1429  convolution_forward(const primitive_desc &aprimitive_desc,
1430  const primitive::at &src, const primitive::at &weights,
1431  const primitive::at &bias, const memory &dst) {
1432  mkldnn_primitive_t result;
1433  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1434  bias.data };
1435  const_mkldnn_primitive_t outputs[] = { dst.get() };
1437  aprimitive_desc.get(), inputs, outputs),
1438  "could not create a convolution forward bias primitive");
1439  reset(result);
1440  }
1441 
1442  convolution_forward(const primitive_desc &aprimitive_desc,
1443  const primitive::at &src, const primitive::at &weights,
1444  const memory &dst) {
1445  mkldnn_primitive_t result;
1446  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1447  const_mkldnn_primitive_t outputs[] = { dst.get() };
1448  check_num_parameters(aprimitive_desc.get(), 2, 1,
1449  "convolution forward");
1451  aprimitive_desc.get(), inputs, outputs),
1452  "could not create a convolution forward primitive");
1453  reset(result);
1454  }
1455 };
1456 
1458  struct desc {
1460  desc(algorithm aalgorithm,
1461  const memory::desc &diff_src_desc,
1462  const memory::desc &weights_desc,
1463  const memory::desc &diff_dst_desc,
1464  const memory::dims strides,
1465  const memory::dims padding_l,
1466  const memory::dims padding_r,
1467  const padding_kind apadding_kind) {
1468  memory::validate_dims(strides);
1469  memory::validate_dims(padding_l);
1470  memory::validate_dims(padding_r);
1472  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1473  &weights_desc.data, &diff_dst_desc.data,
1474  &strides[0], &padding_l[0], &padding_r[0],
1475  mkldnn::convert_to_c(apadding_kind)),
1476  "could not create a convolution backward data descriptor");
1477  }
1478  desc(algorithm aalgorithm,
1479  const memory::desc &diff_src_desc,
1480  const memory::desc &weights_desc,
1481  const memory::desc &diff_dst_desc,
1482  const memory::dims strides,
1483  const memory::dims dilates,
1484  const memory::dims padding_l,
1485  const memory::dims padding_r,
1486  const padding_kind apadding_kind) {
1487  memory::validate_dims(strides);
1488  memory::validate_dims(dilates);
1489  memory::validate_dims(padding_l);
1490  memory::validate_dims(padding_r);
1493  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1494  &weights_desc.data, &diff_dst_desc.data,
1495  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1496  mkldnn::convert_to_c(apadding_kind)),
1497  "could not create a convolution backward data descriptor");
1498  }
1499  };
1500 
1502  primitive_desc(const desc &desc, const engine &e,
1503  const convolution_forward::primitive_desc &hint_fwd_pd)
1504  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1505 
1506  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1507  const convolution_forward::primitive_desc &hint_fwd_pd)
1508  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1509 
1510  REG_QUERY_MPD(diff_src, diff_src, 0);
1511  REG_QUERY_MPD(weights, weights, 0);
1512  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1513  };
1514 
1516  const primitive::at &diff_dst, const primitive::at &weights,
1517  const memory &diff_src) {
1518  mkldnn_primitive_t result;
1519  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1520  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1521  check_num_parameters(aprimitive_desc.get(), 2, 1,
1522  "convolution backward data");
1524  aprimitive_desc.get(), inputs, outputs),
1525  "could not create a convolution backward data primitive");
1526  reset(result);
1527  }
1528 };
1529 
1531  struct desc {
1533  desc(algorithm aalgorithm,
1534  const memory::desc &src_desc,
1535  const memory::desc &diff_weights_desc,
1536  const memory::desc &diff_bias_desc,
1537  const memory::desc &diff_dst_desc,
1538  const memory::dims strides,
1539  const memory::dims padding_l,
1540  const memory::dims padding_r,
1541  const padding_kind apadding_kind) {
1542  memory::validate_dims(strides);
1543  memory::validate_dims(padding_l);
1544  memory::validate_dims(padding_r);
1546  &data, convert_to_c(aalgorithm), &src_desc.data,
1547  &diff_weights_desc.data, &diff_bias_desc.data,
1548  &diff_dst_desc.data,
1549  &strides[0], &padding_l[0], &padding_r[0],
1550  mkldnn::convert_to_c(apadding_kind)),
1551  "could not create a convolution backward weights descriptor");
1552  }
1553  desc(algorithm aalgorithm,
1554  const memory::desc &src_desc,
1555  const memory::desc &diff_weights_desc,
1556  const memory::desc &diff_dst_desc,
1557  const memory::dims strides,
1558  const memory::dims padding_l,
1559  const memory::dims padding_r,
1560  const padding_kind apadding_kind) {
1561  memory::validate_dims(strides);
1562  memory::validate_dims(padding_l);
1563  memory::validate_dims(padding_r);
1565  &data, convert_to_c(aalgorithm), &src_desc.data,
1566  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1567  &strides[0], &padding_l[0], &padding_r[0],
1568  mkldnn::convert_to_c(apadding_kind)),
1569  "could not create a convolution backward weights descriptor");
1570  }
1571  desc(algorithm aalgorithm,
1572  const memory::desc &src_desc,
1573  const memory::desc &diff_weights_desc,
1574  const memory::desc &diff_bias_desc,
1575  const memory::desc &diff_dst_desc,
1576  const memory::dims strides,
1577  const memory::dims dilates,
1578  const memory::dims padding_l,
1579  const memory::dims padding_r,
1580  const padding_kind apadding_kind) {
1581  memory::validate_dims(strides);
1582  memory::validate_dims(dilates);
1583  memory::validate_dims(padding_l);
1584  memory::validate_dims(padding_r);
1586  &data, convert_to_c(aalgorithm), &src_desc.data,
1587  &diff_weights_desc.data, &diff_bias_desc.data,
1588  &diff_dst_desc.data,
1589  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1590  mkldnn::convert_to_c(apadding_kind)),
1591  "could not create a convolution backward weights descriptor");
1592  }
1593  desc(algorithm aalgorithm,
1594  const memory::desc &src_desc,
1595  const memory::desc &diff_weights_desc,
1596  const memory::desc &diff_dst_desc,
1597  const memory::dims strides,
1598  const memory::dims dilates,
1599  const memory::dims padding_l,
1600  const memory::dims padding_r,
1601  const padding_kind apadding_kind) {
1602  memory::validate_dims(strides);
1603  memory::validate_dims(dilates);
1604  memory::validate_dims(padding_l);
1605  memory::validate_dims(padding_r);
1607  &data, convert_to_c(aalgorithm), &src_desc.data,
1608  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1609  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1610  mkldnn::convert_to_c(apadding_kind)),
1611  "could not create a convolution backward weights descriptor");
1612  }
1613 
1614  };
1615 
1617  primitive_desc(const desc &desc, const engine &e,
1618  const convolution_forward::primitive_desc &hint_fwd_pd)
1619  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1620 
1621  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1622  const convolution_forward::primitive_desc &hint_fwd_pd)
1623  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1624 
1625  REG_QUERY_MPD(src, src, 0);
1626  REG_QUERY_MPD(diff_weights, diff_weights, 0);
1627  REG_QUERY_MPD(diff_bias, diff_weights, 1);
1628  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1629  };
1630 
1632  const primitive::at &src, const primitive::at &diff_dst,
1633  const memory &diff_weights, const memory &diff_bias) {
1634  mkldnn_primitive_t result;
1635  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1636  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
1637  diff_bias.get() };
1638  check_num_parameters(aprimitive_desc.get(), 2, 2,
1639  "convolution backward weights");
1641  aprimitive_desc.get(), inputs, outputs),
1642  "could not create a convolution backward weights primitive");
1643  reset(result);
1644  }
1646  const primitive::at &src, const primitive::at &diff_dst,
1647  const memory &diff_weights) {
1648  mkldnn_primitive_t result;
1649  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1650  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
1651  check_num_parameters(aprimitive_desc.get(), 2, 1,
1652  "convolution backward weights");
1654  aprimitive_desc.get(), inputs, outputs),
1655  "could not create a convolution backward weights primitive");
1656  reset(result);
1657  }
1658 };
1659 
1661 //
1667 
1669  struct desc {
1671  desc(prop_kind aprop_kind, algorithm aalgorithm,
1672  const memory::desc &src_desc,
1673  const memory::desc &weights_desc,
1674  const memory::desc &bias_desc,
1675  const memory::desc &dst_desc,
1676  const memory::dims strides,
1677  const memory::dims padding_l,
1678  const memory::dims padding_r,
1679  const padding_kind apadding_kind) {
1680  memory::validate_dims(strides);
1681  memory::validate_dims(padding_l);
1682  memory::validate_dims(padding_r);
1684  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1685  &src_desc.data, &weights_desc.data, &bias_desc.data,
1686  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1687  mkldnn::convert_to_c(apadding_kind)),
1688  "could not create a deconvolution forward descriptor");
1689  }
1690  desc(prop_kind aprop_kind, algorithm aalgorithm,
1691  const memory::desc &src_desc,
1692  const memory::desc &weights_desc,
1693  const memory::desc &dst_desc,
1694  const memory::dims strides,
1695  const memory::dims padding_l,
1696  const memory::dims padding_r,
1697  const padding_kind apadding_kind) {
1698  memory::validate_dims(strides);
1699  memory::validate_dims(padding_l);
1700  memory::validate_dims(padding_r);
1702  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1703  &src_desc.data, &weights_desc.data, nullptr,
1704  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1705  mkldnn::convert_to_c(apadding_kind)),
1706  "could not create a deconvolution forward descriptor");
1707  }
1708  desc(prop_kind aprop_kind, algorithm aalgorithm,
1709  const memory::desc &src_desc,
1710  const memory::desc &weights_desc,
1711  const memory::desc &bias_desc,
1712  const memory::desc &dst_desc,
1713  const memory::dims strides,
1714  const memory::dims dilates,
1715  const memory::dims padding_l,
1716  const memory::dims padding_r,
1717  const padding_kind apadding_kind) {
1718  memory::validate_dims(strides);
1719  memory::validate_dims(dilates);
1720  memory::validate_dims(padding_l);
1721  memory::validate_dims(padding_r);
1723  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1724  &src_desc.data, &weights_desc.data, &bias_desc.data,
1725  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1726  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1727  "could not create a dilated deconvolution forward descriptor");
1728  }
1729  desc(prop_kind aprop_kind, algorithm aalgorithm,
1730  const memory::desc &src_desc,
1731  const memory::desc &weights_desc,
1732  const memory::desc &dst_desc,
1733  const memory::dims strides,
1734  const memory::dims dilates,
1735  const memory::dims padding_l,
1736  const memory::dims padding_r,
1737  const padding_kind apadding_kind) {
1738  memory::validate_dims(strides);
1739  memory::validate_dims(dilates);
1740  memory::validate_dims(padding_l);
1741  memory::validate_dims(padding_r);
1743  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1744  &src_desc.data, &weights_desc.data, nullptr,
1745  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1746  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1747  "could not create a dilated deconvolution forward descriptor");
1748  }
1749  };
1750 
1752  primitive_desc(const desc &desc, const engine &e)
1753  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1754 
1755  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1756  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1757 
1758  REG_QUERY_MPD(src, src, 0);
1759  REG_QUERY_MPD(weights, weights, 0);
1760  REG_QUERY_MPD(bias, weights, 1);
1761  REG_QUERY_MPD(dst, dst, 0);
1762  };
1763 
1764  deconvolution_forward(const primitive_desc &aprimitive_desc,
1765  const primitive::at &src, const primitive::at &weights,
1766  const primitive::at &bias, const memory &dst) {
1767  mkldnn_primitive_t result;
1768  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1769  bias.data };
1770  const_mkldnn_primitive_t outputs[] = { dst.get() };
1771  check_num_parameters(aprimitive_desc.get(), 3, 1,
1772  "deconvolution forward");
1774  aprimitive_desc.get(), inputs, outputs),
1775  "could not create a deconvolution forward bias primitive");
1776  reset(result);
1777  }
1778 
1779  deconvolution_forward(const primitive_desc &aprimitive_desc,
1780  const primitive::at &src, const primitive::at &weights,
1781  const memory &dst) {
1782  mkldnn_primitive_t result;
1783  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1784  const_mkldnn_primitive_t outputs[] = { dst.get() };
1785  check_num_parameters(aprimitive_desc.get(), 2, 1,
1786  "deconvolution forward");
1788  aprimitive_desc.get(), inputs, outputs),
1789  "could not create a deconvolution forward primitive");
1790  reset(result);
1791  }
1792 };
1793 
1795  struct desc {
1797  desc(algorithm aalgorithm,
1798  const memory::desc &diff_src_desc,
1799  const memory::desc &weights_desc,
1800  const memory::desc &diff_dst_desc,
1801  const memory::dims strides,
1802  const memory::dims padding_l,
1803  const memory::dims padding_r,
1804  const padding_kind apadding_kind) {
1805  memory::validate_dims(strides);
1806  memory::validate_dims(padding_l);
1807  memory::validate_dims(padding_r);
1809  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1810  &weights_desc.data, &diff_dst_desc.data,
1811  &strides[0], &padding_l[0], &padding_r[0],
1812  mkldnn::convert_to_c(apadding_kind)),
1813  "could not create a deconvolution backward data descriptor");
1814  }
1815  desc(algorithm aalgorithm,
1816  const memory::desc &diff_src_desc,
1817  const memory::desc &weights_desc,
1818  const memory::desc &diff_dst_desc,
1819  const memory::dims strides,
1820  const memory::dims dilates,
1821  const memory::dims padding_l,
1822  const memory::dims padding_r,
1823  const padding_kind apadding_kind) {
1824  memory::validate_dims(strides);
1825  memory::validate_dims(dilates);
1826  memory::validate_dims(padding_l);
1827  memory::validate_dims(padding_r);
1829  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1830  &weights_desc.data, &diff_dst_desc.data,
1831  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1832  mkldnn::convert_to_c(apadding_kind)),
1833  "could not create a dilated deconvolution backward data descriptor");
1834  }
1835  };
1836 
1838  primitive_desc(const desc &desc, const engine &e,
1839  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1840  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1841 
1842  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1843  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1844  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1845 
1846  REG_QUERY_MPD(diff_src, diff_src, 0);
1847  REG_QUERY_MPD(weights, weights, 0);
1848  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1849  };
1850 
1852  const primitive::at &diff_dst, const primitive::at &weights,
1853  const memory &diff_src) {
1854  mkldnn_primitive_t result;
1855  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1856  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1857  check_num_parameters(aprimitive_desc.get(), 2, 1,
1858  "deconvolution backward data");
1860  aprimitive_desc.get(), inputs, outputs),
1861  "could not create a deconvolution backward data primitive");
1862  reset(result);
1863  }
1864 };
1865 
1867  struct desc {
1869  desc(algorithm aalgorithm,
1870  const memory::desc &src_desc,
1871  const memory::desc &diff_weights_desc,
1872  const memory::desc &diff_bias_desc,
1873  const memory::desc &diff_dst_desc,
1874  const memory::dims strides,
1875  const memory::dims padding_l,
1876  const memory::dims padding_r,
1877  const padding_kind apadding_kind) {
1878  memory::validate_dims(strides);
1879  memory::validate_dims(padding_l);
1880  memory::validate_dims(padding_r);
1882  &data, convert_to_c(aalgorithm), &src_desc.data,
1883  &diff_weights_desc.data, &diff_bias_desc.data,
1884  &diff_dst_desc.data,
1885  &strides[0], &padding_l[0], &padding_r[0],
1886  mkldnn::convert_to_c(apadding_kind)),
1887  "could not create a deconvolution backward weights descriptor");
1888  }
1889  desc(algorithm aalgorithm,
1890  const memory::desc &src_desc,
1891  const memory::desc &diff_weights_desc,
1892  const memory::desc &diff_dst_desc,
1893  const memory::dims strides,
1894  const memory::dims padding_l,
1895  const memory::dims padding_r,
1896  const padding_kind apadding_kind) {
1897  memory::validate_dims(strides);
1898  memory::validate_dims(padding_l);
1899  memory::validate_dims(padding_r);
1901  &data, convert_to_c(aalgorithm), &src_desc.data,
1902  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1903  &strides[0], &padding_l[0], &padding_r[0],
1904  mkldnn::convert_to_c(apadding_kind)),
1905  "could not create a deconvolution backward weights descriptor");
1906  }
1907  desc(algorithm aalgorithm,
1908  const memory::desc &src_desc,
1909  const memory::desc &diff_weights_desc,
1910  const memory::desc &diff_bias_desc,
1911  const memory::desc &diff_dst_desc,
1912  const memory::dims strides,
1913  const memory::dims dilates,
1914  const memory::dims padding_l,
1915  const memory::dims padding_r,
1916  const padding_kind apadding_kind) {
1917  memory::validate_dims(strides);
1918  memory::validate_dims(dilates);
1919  memory::validate_dims(padding_l);
1920  memory::validate_dims(padding_r);
1922  &data, convert_to_c(aalgorithm), &src_desc.data,
1923  &diff_weights_desc.data, &diff_bias_desc.data,
1924  &diff_dst_desc.data,
1925  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1926  mkldnn::convert_to_c(apadding_kind)),
1927  "could not create a dilated deconvolution backward weights descriptor");
1928  }
1929  desc(algorithm aalgorithm,
1930  const memory::desc &src_desc,
1931  const memory::desc &diff_weights_desc,
1932  const memory::desc &diff_dst_desc,
1933  const memory::dims strides,
1934  const memory::dims dilates,
1935  const memory::dims padding_l,
1936  const memory::dims padding_r,
1937  const padding_kind apadding_kind) {
1938  memory::validate_dims(strides);
1939  memory::validate_dims(dilates);
1940  memory::validate_dims(padding_l);
1941  memory::validate_dims(padding_r);
1943  &data, convert_to_c(aalgorithm), &src_desc.data,
1944  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1945  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1946  mkldnn::convert_to_c(apadding_kind)),
1947  "could not create a dilated deconvolution backward weights descriptor");
1948  }
1949  };
1950 
1952  primitive_desc(const desc &desc, const engine &e,
1953  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1954  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1955 
1956  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1957  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1958  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1959 
1960  REG_QUERY_MPD(src, src, 0);
1961  REG_QUERY_MPD(diff_weights, diff_weights, 0);
1962  REG_QUERY_MPD(diff_bias, diff_weights, 1);
1963  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1964  };
1965 
1967  const primitive::at &src, const primitive::at &diff_dst,
1968  const memory &diff_weights, const memory &diff_bias) {
1969  mkldnn_primitive_t result;
1970  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1971  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
1972  diff_bias.get() };
1973  check_num_parameters(aprimitive_desc.get(), 2, 2,
1974  "deconvolution backward weights");
1976  aprimitive_desc.get(), inputs, outputs),
1977  "could not create a deconvolution backward weights primitive");
1978  reset(result);
1979  }
1981  const primitive::at &src, const primitive::at &diff_dst,
1982  const memory &diff_weights) {
1983  mkldnn_primitive_t result;
1984  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1985  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
1986  check_num_parameters(aprimitive_desc.get(), 2, 1,
1987  "deconvolution backward weights");
1989  aprimitive_desc.get(), inputs, outputs),
1990  "could not create a deconvolution backward weights primitive");
1991  reset(result);
1992  }
1993 };
1994 
1996 
2003 
2004 struct lrn_forward : public primitive {
2005  struct desc {
2007  desc(prop_kind aprop_kind, algorithm aalgorithm,
2008  const memory::desc &src_desc,
2009  int local_size, float alpha, float beta, float k)
2010  {
2012  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2013  &src_desc.data, local_size, alpha, beta, k),
2014  "could not create a lrn forward descriptor");
2015  }
2016  desc(prop_kind aprop_kind, algorithm aalgorithm,
2017  const memory::desc &src_desc,
2018  int local_size, float alpha, float beta)
2019  {
2021  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2022  &src_desc.data, local_size, alpha, beta, float(1.0)),
2023  "could not create a lrn forward descriptor");
2024  }
2025  };
2026 
2028  primitive_desc(const desc &desc, const engine &e)
2029  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2030 
2031  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2032  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2033 
2034  REG_QUERY_MPD(src, src, 0);
2035  REG_QUERY_MPD(dst, dst, 0);
2036  REG_QUERY_MPD(workspace, workspace, 0);
2037  };
2038 
2039  lrn_forward(const primitive_desc &aprimitive_desc,
2040  const primitive::at &src, const memory &workspace,
2041  const memory &dst) {
2042  mkldnn_primitive_t result;
2043  mkldnn_primitive_at_t inputs[] = { src.data };
2044  const_mkldnn_primitive_t outputs[] = { dst.get(),
2045  workspace.get() };
2046  check_num_parameters(aprimitive_desc.get(), 1, 2, "lrn forward");
2048  aprimitive_desc.get(), inputs, outputs),
2049  "could not create a lrn forward primitive");
2050  reset(result);
2051  }
2052 
2053  lrn_forward(const primitive_desc &aprimitive_desc,
2054  const primitive::at &src, const memory &dst) {
2055  mkldnn_primitive_t result;
2056  mkldnn_primitive_at_t inputs[] = { src.data };
2057  const_mkldnn_primitive_t outputs[] = { dst.get() };
2058  check_num_parameters(aprimitive_desc.get(), 1, 1, "lrn forward");
2060  aprimitive_desc.get(), inputs, outputs),
2061  "could not create a lrn forward primitive");
2062  reset(result);
2063  }
2064 };
2065 
2066 struct lrn_backward : public primitive {
2067  struct desc {
2069  desc(algorithm aalgorithm,
2070  const memory::desc &data_desc,
2071  const memory::desc &diff_data_desc,
2072  int local_size, float alpha, float beta, float k)
2073  {
2075  convert_to_c(aalgorithm), &diff_data_desc.data,
2076  &data_desc.data, local_size, alpha, beta, k),
2077  "could not create a lrn backward descriptor");
2078  }
2079  desc(algorithm aalgorithm,
2080  const memory::desc &data_desc,
2081  const memory::desc &diff_data_desc,
2082  int local_size, float alpha, float beta)
2083  {
2085  convert_to_c(aalgorithm), &diff_data_desc.data,
2086  &data_desc.data, local_size, alpha, beta, float(1.0)),
2087  "could not create a lrn backward descriptor");
2088  }
2089  };
2090 
2092  primitive_desc(const desc &desc, const engine &e,
2093  const lrn_forward::primitive_desc &hint_fwd_pd)
2094  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2095 
2096  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2097  const lrn_forward::primitive_desc &hint_fwd_pd)
2098  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2099 
2100  REG_QUERY_MPD(diff_src, diff_src, 0);
2101  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2102  REG_QUERY_MPD(workspace, workspace, 0);
2103  };
2104 
2105  lrn_backward(const primitive_desc &aprimitive_desc,
2106  const primitive::at &src, const primitive::at &diff_dst,
2107  const primitive::at &workspace, const memory &diff_src) {
2108  mkldnn_primitive_t result;
2109  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data,
2110  workspace.data };
2111  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2112  check_num_parameters(aprimitive_desc.get(), 3, 1, "lrn backward");
2114  aprimitive_desc.get(), inputs, outputs),
2115  "could not create a lrn backward primitive");
2116  reset(result);
2117  }
2118 
2119  lrn_backward(const primitive_desc &aprimitive_desc,
2120  const primitive::at &src, const primitive::at &diff_dst,
2121  const memory &diff_src) {
2122  mkldnn_primitive_t result;
2123  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2124  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2125  check_num_parameters(aprimitive_desc.get(), 2, 1, "lrn backward");
2127  aprimitive_desc.get(), inputs, outputs),
2128  "could not create a lrn backward primitive");
2129  reset(result);
2130  }
2131 };
2132 
2134 
2140 
2141 struct pooling_forward : public primitive {
2142  struct desc {
2144  desc(prop_kind aprop_kind, algorithm aalgorithm,
2145  const memory::desc &src_desc,
2146  const memory::desc &dst_desc,
2147  const memory::dims strides,
2148  const memory::dims kernel,
2149  const memory::dims padding_l,
2150  const memory::dims padding_r,
2151  const padding_kind apadding_kind) {
2152  memory::validate_dims(strides);
2153  memory::validate_dims(kernel);
2154  memory::validate_dims(padding_l);
2155  memory::validate_dims(padding_r);
2157  mkldnn::convert_to_c(aprop_kind),
2158  convert_to_c(aalgorithm),
2159  &src_desc.data, &dst_desc.data,
2160  &strides[0], &kernel[0],
2161  &padding_l[0], &padding_r[0],
2162  mkldnn::convert_to_c(apadding_kind)),
2163  "could not init a forward pooling descriptor");
2164  }
2165  };
2166 
2168  primitive_desc(const desc &desc, const engine &e)
2169  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2170 
2171  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2172  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2173 
2174  REG_QUERY_MPD(src, src, 0);
2175  REG_QUERY_MPD(dst, dst, 0);
2176  REG_QUERY_MPD(workspace, workspace, 0);
2177  };
2178 
2179  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2180  const memory &dst) {
2181  mkldnn_primitive_t result;
2182  mkldnn_primitive_at_t inputs[] = { src.data };
2183  const_mkldnn_primitive_t outputs[] = { dst.get(), nullptr };
2184  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling forward");
2186  aprimitive_desc.get(), inputs, outputs),
2187  "could not create a pooling forward primitive");
2188  reset(result);
2189  }
2190 
2191  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2192  const memory &dst, const memory &workspace) {
2193  mkldnn_primitive_t result;
2194  mkldnn_primitive_at_t inputs[] = { src.data };
2195  const_mkldnn_primitive_t outputs[] = { dst.get(), workspace.get() };
2196  check_num_parameters(aprimitive_desc.get(), 1, 2, "pooling forward");
2198  aprimitive_desc.get(), inputs, outputs),
2199  "could not create a pooling forward primitive");
2200  reset(result);
2201  }
2202 };
2203 
2204 struct pooling_backward : public primitive {
2205  struct desc {
2207  desc(algorithm aalgorithm,
2208  const memory::desc &diff_src_desc,
2209  const memory::desc &diff_dst_desc,
2210  const memory::dims &strides,
2211  const memory::dims &kernel,
2212  const memory::dims &padding_l,
2213  const memory::dims &padding_r,
2214  const padding_kind apadding_kind) {
2215  memory::validate_dims(strides);
2216  memory::validate_dims(kernel);
2217  memory::validate_dims(padding_l);
2218  memory::validate_dims(padding_r);
2220  convert_to_c(aalgorithm),
2221  &diff_src_desc.data, &diff_dst_desc.data,
2222  &strides[0], &kernel[0],
2223  &padding_l[0], &padding_r[0],
2224  mkldnn::convert_to_c(apadding_kind)),
2225  "could not init a backward pooling descriptor");
2226  }
2227  };
2228 
2230  primitive_desc(const desc &desc, const engine &e,
2231  const pooling_forward::primitive_desc &hint_fwd_pd)
2232  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2233 
2234  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2235  const pooling_forward::primitive_desc &hint_fwd_pd)
2236  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2237 
2238  REG_QUERY_MPD(diff_src, diff_src, 0);
2239  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2240  REG_QUERY_MPD(workspace, workspace, 0);
2241  };
2242 
2243  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2244  const memory &diff_src) {
2245  mkldnn_primitive_t result;
2246  mkldnn_primitive_at_t inputs[] = { diff_dst.data };
2247  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2248  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling backward");
2250  aprimitive_desc.get(), inputs, outputs),
2251  "could not create a pooling backward primitive");
2252  reset(result);
2253  }
2254 
2255  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2256  const primitive::at &workspace, const memory &diff_src) {
2257  mkldnn_primitive_t result;
2258  mkldnn_primitive_at_t inputs[] = { diff_dst.data, workspace.data };
2259  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2260  check_num_parameters(aprimitive_desc.get(), 2, 1, "pooling backward");
2262  aprimitive_desc.get(), inputs, outputs),
2263  "could not create a pooling backward primitive");
2264  reset(result);
2265  }
2266 };
2267 
2269 
2276 
2277 struct eltwise_forward : public primitive {
2278  struct desc {
2280  template <typename T>
2281  desc(prop_kind aprop_kind, algorithm alg_kind,
2282  const memory::desc &src_desc, T alpha = 0, T beta = 0) {
2284  mkldnn::convert_to_c(aprop_kind),
2285  mkldnn::convert_to_c(alg_kind), &src_desc.data,
2286  static_cast<float>(alpha), static_cast<float>(beta)),
2287  "could not create a eltwise forward descriptor");
2288  }
2289  };
2290 
2292  primitive_desc(const desc &desc, const engine &e)
2293  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2294 
2295  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2296  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2297 
2298  REG_QUERY_MPD(src, src, 0);
2299  REG_QUERY_MPD(dst, dst, 0);
2300  };
2301 
2302  eltwise_forward(const primitive_desc &aprimitive_desc,
2303  const primitive::at &src, const memory &dst) {
2304  mkldnn_primitive_t result;
2305  mkldnn_primitive_at_t inputs[] = { src.data };
2306  const_mkldnn_primitive_t outputs[] = { dst.get() };
2307  check_num_parameters(aprimitive_desc.get(), 1, 1, "eltwise forward");
2309  aprimitive_desc.get(), inputs, outputs),
2310  "could not create a eltwise forward primitive");
2311  reset(result);
2312  }
2313 };
2314 
2315 struct eltwise_backward : public primitive {
2316  struct desc {
2318 
2319  template <typename T>
2320  desc(algorithm alg_kind, const memory::desc &diff_data_desc,
2321  const memory::desc &data_desc, T alpha = 0, T beta = 0) {
2323  mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
2324  &data_desc.data, static_cast<float>(alpha),
2325  static_cast<float>(beta)),
2326  "could not create a eltwise backward descriptor");
2327  }
2328  };
2329 
2331  primitive_desc(const desc &desc, const engine &e,
2332  const eltwise_forward::primitive_desc &hint_fwd_pd)
2333  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2334 
2335  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2336  const eltwise_forward::primitive_desc &hint_fwd_pd)
2337  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2338 
2339  REG_QUERY_MPD(src, src, 0);
2340  REG_QUERY_MPD(diff_src, diff_src, 0);
2341  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2342  };
2343 
2344  eltwise_backward(const primitive_desc &aprimitive_desc,
2345  const primitive::at &src, const primitive::at &diff_dst,
2346  const memory &diff_src) {
2347  mkldnn_primitive_t result;
2348  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2349  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2350  check_num_parameters(aprimitive_desc.get(), 2, 1, "eltwise backward");
2352  aprimitive_desc.get(), inputs, outputs),
2353  "could not create a eltwise backward primitive");
2354  reset(result);
2355  }
2356 };
2357 
2359 
2365 
2366 struct softmax_forward : public primitive {
2367  struct desc {
2369  desc(prop_kind aprop_kind, const memory::desc &data_desc,
2370  int softmax_axis) {
2372  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
2373  softmax_axis),
2374  "could not create a softmax forward descriptor");
2375  }
2376  };
2377 
2379  primitive_desc(const desc &desc, const engine &e)
2380  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2381 
2382  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2383  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2384 
2385  REG_QUERY_MPD(src, src, 0);
2386  REG_QUERY_MPD(dst, dst, 0);
2387  };
2388 
2389  softmax_forward(const primitive_desc &aprimitive_desc,
2390  const primitive::at &src, const memory &dst) {
2391  mkldnn_primitive_t result;
2392  mkldnn_primitive_at_t inputs[] = { src.data };
2393  const_mkldnn_primitive_t outputs[] = { dst.get() };
2394  check_num_parameters(aprimitive_desc.get(), 1, 1, "softmax forward");
2396  aprimitive_desc.get(), inputs, outputs),
2397  "could not create a softmax forward primitive");
2398  reset(result);
2399  }
2400 };
2401 
2402 struct softmax_backward : public primitive {
2403  struct desc {
2405  desc(const memory::desc &diff_desc, const memory::desc &data_desc,
2406  int softmax_axis) {
2408  &diff_desc.data, &data_desc.data, softmax_axis),
2409  "could not init a backward softmax descriptor");
2410  }
2411  };
2412 
2414  primitive_desc(const desc &desc, const engine &e,
2415  const softmax_forward::primitive_desc &hint_fwd_pd)
2416  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2417 
2418  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2419  const softmax_forward::primitive_desc &hint_fwd_pd)
2420  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2421 
2422  REG_QUERY_MPD(dst, dst, 0);
2423  REG_QUERY_MPD(diff_src, diff_src, 0);
2424  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2425  REG_QUERY_MPD(workspace, workspace, 0);
2426  };
2427 
2428  softmax_backward(const primitive_desc &aprimitive_desc,
2429  const primitive::at &dst, const primitive::at &diff_dst,
2430  const memory &diff_src) {
2431  mkldnn_primitive_t result;
2432  mkldnn_primitive_at_t inputs[] = { dst.data, diff_dst.data };
2433  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2435  aprimitive_desc.get(), inputs, outputs),
2436  "could not create a softmax backward primitive");
2437  reset(result);
2438  }
2439 };
2440 
2442 
2448 
2450  struct desc {
2452  template <typename T>
2453  desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
2454  unsigned flags) {
2457  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2458  static_cast<float>(epsilon), flags),
2459  "could not create a batch normalization forward descriptor");
2460  }
2461  };
2462 
2464  primitive_desc(const desc &desc, const engine &e)
2465  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2466 
2467  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2468  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2469 
2470  REG_QUERY_MPD(src, src, 0);
2471  REG_QUERY_MPD(weights, weights, 0);
2472  REG_QUERY_MPD(dst, dst, 0);
2473  REG_QUERY_MPD(workspace, workspace, 0);
2474 
2476  { return stat_primitive_desc(mean); }
2478  { return stat_primitive_desc(var); }
2479 
2480  private:
2481  enum { mean = 1, var = 2, };
2482  memory::primitive_desc stat_primitive_desc(int kind) const {
2486  "could not get a batch-normalization descriptor");
2487  return query_mpd(p->flags & use_global_stats ? src_pd : dst_pd, kind);
2488  }
2489  };
2490 
2492  const primitive::at &src, const primitive::at &mean,
2493  const primitive::at &variance, const primitive::at &weights,
2494  const memory &dst) {
2495  mkldnn_primitive_t result;
2496  mkldnn_primitive_at_t inputs[] = { src.data,
2497  mean.data, variance.data, weights.data };
2498  const_mkldnn_primitive_t outputs[] = { dst.get() };
2499  check_num_parameters(aprimitive_desc.get(), 4, 1,
2500  "batch normalization forward");
2502  aprimitive_desc.get(), inputs, outputs),
2503  "could not create a batch normalization forward primitive");
2504  reset(result);
2505  }
2506 
2508  const primitive::at &src, const primitive::at &mean,
2509  const primitive::at &variance, const memory &dst) {
2510  mkldnn_primitive_t result;
2511  mkldnn_primitive_at_t inputs[] = { src.data,
2512  mean.data, variance.data };
2513  const_mkldnn_primitive_t outputs[] = { dst.get() };
2514  check_num_parameters(aprimitive_desc.get(), 3, 1,
2515  "batch normalization forward");
2517  aprimitive_desc.get(), inputs, outputs),
2518  "could not create a batch normalization forward primitive");
2519  reset(result);
2520  }
2521 
2530  const primitive::at &src, const primitive::at &weights,
2531  const memory &dst, const memory &mean, const memory &variance) {
2532  mkldnn_primitive_t result;
2533  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2534  const_mkldnn_primitive_t outputs[] = { dst.get(),
2535  mean.get(), variance.get() };
2536  check_num_parameters(aprimitive_desc.get(), 2, 3,
2537  "batch normalization forward");
2539  aprimitive_desc.get(), inputs, outputs),
2540  "could not create a batch normalization forward primitive");
2541  reset(result);
2542  }
2543 
2545  const primitive::at &src, const primitive::at &weights,
2546  const memory &dst, const memory &mean, const memory &variance,
2547  const memory &workspace) {
2548  mkldnn_primitive_t result;
2549  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2550  const_mkldnn_primitive_t outputs[] = { dst.get(),
2551  mean.get(), variance.get(), workspace.get() };
2552  check_num_parameters(aprimitive_desc.get(), 2, 4,
2553  "batch normalization forward");
2555  aprimitive_desc.get(), inputs, outputs),
2556  "could not create a batch normalization forward primitive");
2557  reset(result);
2558  }
2559 
2561  const primitive::at &src, const memory &dst, const memory &mean,
2562  const memory &variance) {
2563  mkldnn_primitive_t result;
2564  mkldnn_primitive_at_t inputs[] = { src.data };
2565  const_mkldnn_primitive_t outputs[] = { dst.get(),
2566  mean.get(), variance.get() };
2567  check_num_parameters(aprimitive_desc.get(), 1, 3,
2568  "batch normalization forward");
2570  aprimitive_desc.get(), inputs, outputs),
2571  "could not create a batch normalization forward primitive");
2572  reset(result);
2573  }
2574 
2587  const primitive::at &src, const memory &dst, const memory &mean,
2588  const memory &variance, const memory &workspace) {
2589  mkldnn_primitive_t result;
2590  mkldnn_primitive_at_t inputs[2] = { src.data };
2591  const_mkldnn_primitive_t outputs[4] = { dst.get(),
2592  mean.get(), variance.get(), workspace.get() };
2593 
2594  if (1) { // check whether this is the `wrong` constructor
2595  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
2596  aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0);
2597  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
2598  aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0);
2599  if (n_inputs_expected == 2 && n_outputs_expected == 3) {
2600  // shift parameters, get rid of workspace, and add weights...
2601  auto _weights = dst;
2602  inputs[1] = {_weights.get(), 0};
2603 
2604  auto _dst = mean, _mean = variance, _variance = workspace;
2605  outputs[0] = _dst.get();
2606  outputs[1] = _mean.get();
2607  outputs[2] = _variance.get();
2608  outputs[3] = nullptr;
2609  }
2610  }
2612  aprimitive_desc.get(), inputs, outputs),
2613  "could not create a batch normalization forward primitive");
2614  reset(result);
2615  }
2616 
2618  const primitive::at &src, const primitive::at &weights,
2619  const memory &dst) {
2620  mkldnn_primitive_t result;
2621  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2622  const_mkldnn_primitive_t outputs[] = { dst.get() };
2623  check_num_parameters(aprimitive_desc.get(), 2, 1,
2624  "batch normalization forward");
2626  aprimitive_desc.get(), inputs, outputs),
2627  "could not create a batch normalization forward primitive");
2628  reset(result);
2629  }
2630 
2632  const primitive::at &src, const memory &dst) {
2633  mkldnn_primitive_t result;
2634  mkldnn_primitive_at_t inputs[] = { src.data };
2635  const_mkldnn_primitive_t outputs[] = { dst.get() };
2636  check_num_parameters(aprimitive_desc.get(), 1, 1,
2637  "batch normalization forward");
2639  aprimitive_desc.get(), inputs, outputs),
2640  "could not create a batch normalization forward primitive");
2641  reset(result);
2642  }
2643 };
2644 
2646  struct desc {
2648  template <typename T>
2649  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
2650  const memory::desc &data_desc, T epsilon, unsigned flags) {
2653  mkldnn::convert_to_c(aprop_kind),
2654  &diff_data_desc.data, &data_desc.data,
2655  static_cast<float>(epsilon), flags),
2656  "could not create a batch normalization backward descriptor");
2657  }
2658  };
2659 
2661  primitive_desc(const desc &desc, const engine &e,
2663  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2664 
2665  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2667  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2668 
2669  REG_QUERY_MPD(src, src, 0);
2670  REG_QUERY_MPD(mean, src, 1);
2671  REG_QUERY_MPD(variance, src, 2);
2672  REG_QUERY_MPD(weights, weights, 0);
2673  REG_QUERY_MPD(dst, dst, 0);
2674  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2675  REG_QUERY_MPD(workspace, workspace, 0);
2676 
2677  REG_QUERY_MPD(diff_src, diff_src, 0);
2678  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2679  };
2680 
2681  // Prop_kind == backward
2683  const primitive::at &src, const primitive::at &mean,
2684  const primitive::at &variance, const primitive::at &diff_dst,
2685  const primitive::at &weights, const memory &diff_src,
2686  const memory &diff_weights) {
2687  mkldnn_primitive_t result;
2688  mkldnn_primitive_at_t inputs[] = { src.data,
2689  mean.data, variance.data, diff_dst.data, weights.data };
2690  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2691  diff_weights.get() };
2692  check_num_parameters(aprimitive_desc.get(), 5, 2,
2693  "batch normalization backward");
2695  aprimitive_desc.get(), inputs, outputs),
2696  "could not create a batch normalization backward primitive");
2697  reset(result);
2698  }
2699 
2700  // Prop_kind == backward (+ws)
2702  const primitive::at &src, const primitive::at &mean,
2703  const primitive::at &variance, const primitive::at &diff_dst,
2704  const primitive::at &weights, const primitive::at &workspace,
2705  const memory &diff_src, const memory &diff_weights) {
2706  mkldnn_primitive_t result;
2707  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2708  diff_dst.data, weights.data, workspace.data };
2709  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2710  diff_weights.get() };
2711  check_num_parameters(aprimitive_desc.get(), 6, 2,
2712  "batch normalization backward");
2714  aprimitive_desc.get(), inputs, outputs),
2715  "could not create a batch normalization backward primitive");
2716  reset(result);
2717  }
2718 
2719  // Prop_kind == backward_data (+ws or +weights)
2724  const primitive::at &src, const primitive::at &mean,
2725  const primitive::at &variance,const primitive::at &diff_dst,
2726  const primitive::at &weights_or_workspace, const memory &diff_src) {
2727  mkldnn_primitive_t result;
2728  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2729  diff_dst.data, weights_or_workspace.data };
2730  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2731  check_num_parameters(aprimitive_desc.get(), 5, 1,
2732  "batch normalization backward");
2734  aprimitive_desc.get(), inputs, outputs),
2735  "could not create a batch normalization backward primitive");
2736  reset(result);
2737  }
2738 
2739  // Prop_kind == backward_data
2741  const primitive::at &src, const primitive::at &mean,
2742  const primitive::at &variance, const primitive::at &diff_dst,
2743  const memory &diff_src) {
2744  mkldnn_primitive_t result;
2745  mkldnn_primitive_at_t inputs[] = { src.data,
2746  mean.data, variance.data, diff_dst.data };
2747  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2748  check_num_parameters(aprimitive_desc.get(), 4, 1,
2749  "batch normalization backward");
2751  aprimitive_desc.get(), inputs, outputs),
2752  "could not create a batch normalization backward primitive");
2753  reset(result);
2754  }
2755 };
2756 
2758 
2764 
2766  struct desc {
2768  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2769  const memory::desc &weights_desc,
2770  const memory::desc &bias_desc,
2771  const memory::desc &dst_desc) {
2774  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2775  &weights_desc.data, &bias_desc.data, &dst_desc.data),
2776  "could not create a inner product forward descriptor");
2777  }
2778 
2779  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2780  const memory::desc &weights_desc,
2781  const memory::desc &dst_desc) {
2784  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2785  &weights_desc.data, nullptr, &dst_desc.data),
2786  "could not create a inner product forward descriptor");
2787  }
2788  };
2789 
2791  primitive_desc(const desc &desc, const engine &e)
2792  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2793 
2794  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2795  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2796 
2797  REG_QUERY_MPD(src, src, 0);
2798  REG_QUERY_MPD(weights, weights, 0);
2799  REG_QUERY_MPD(bias, weights, 1);
2800  REG_QUERY_MPD(dst, dst, 0);
2801  };
2802 
2803  inner_product_forward(const primitive_desc &aprimitive_desc,
2804  const primitive::at &src, const primitive::at weights,
2805  const primitive::at &bias, const memory &dst) {
2806  mkldnn_primitive_t result;
2807  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
2808  bias.data };
2809  const_mkldnn_primitive_t outputs[] = { dst.get() };
2810  check_num_parameters(aprimitive_desc.get(), 3, 1,
2811  "inner product forward");
2813  aprimitive_desc.get(), inputs, outputs),
2814  "could not create a inner product forward primitive");
2815  reset(result);
2816  }
2817 
2818  inner_product_forward(const primitive_desc &aprimitive_desc,
2819  const primitive::at &src, const primitive::at weights,
2820  const memory &dst) {
2821  mkldnn_primitive_t result;
2822  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2823  const_mkldnn_primitive_t outputs[] = { dst.get() };
2824  check_num_parameters(aprimitive_desc.get(), 2, 1,
2825  "inner product forward");
2827  aprimitive_desc.get(), inputs, outputs),
2828  "could not create a inner product forward primitive");
2829  reset(result);
2830  }
2831 };
2832 
2834  struct desc {
2836  desc(const memory::desc &diff_src_desc,
2837  const memory::desc &weights_desc,
2838  const memory::desc &diff_dst_desc) {
2841  &diff_src_desc.data, &weights_desc.data,
2842  &diff_dst_desc.data),
2843  "could not create a inner product backward data descriptor");
2844  }
2845  };
2846 
2848  primitive_desc(const desc &desc, const engine &e,
2849  const inner_product_forward::primitive_desc &hint_fwd_pd)
2850  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2851 
2852  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2853  const inner_product_forward::primitive_desc &hint_fwd_pd)
2854  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2855 
2856  REG_QUERY_MPD(diff_src, diff_src, 0);
2857  REG_QUERY_MPD(weights, weights, 0);
2858  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2859  };
2860 
2862  const primitive::at &diff_dst, const primitive::at weights,
2863  const memory &diff_src) {
2864  mkldnn_primitive_t result;
2865  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
2866  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2867  check_num_parameters(aprimitive_desc.get(), 2, 1,
2868  "inner product backward data");
2870  aprimitive_desc.get(), inputs, outputs),
2871  "could not create a inner product backward data primitive");
2872  reset(result);
2873  }
2874 };
2875 
2877  struct desc {
2879  desc(const memory::desc &src_desc,
2880  const memory::desc &diff_weights_desc,
2881  const memory::desc &diff_bias_desc,
2882  const memory::desc &diff_dst_desc) {
2885  &data, &src_desc.data, &diff_weights_desc.data,
2886  &diff_bias_desc.data, &diff_dst_desc.data),
2887  "could not create a inner product backward weights descriptor");
2888  }
2889  desc(const memory::desc &src_desc,
2890  const memory::desc &diff_weights_desc,
2891  const memory::desc &diff_dst_desc) {
2894  &data, &src_desc.data, &diff_weights_desc.data,
2895  nullptr, &diff_dst_desc.data),
2896  "could not create a inner product backward weights descriptor");
2897  }
2898  };
2899 
2901  primitive_desc(const desc &desc, const engine &e,
2902  const inner_product_forward::primitive_desc &hint_fwd_pd)
2903  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2904 
2905  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2906  const inner_product_forward::primitive_desc &hint_fwd_pd)
2907  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2908 
2909  REG_QUERY_MPD(src, src, 0);
2910  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2911  REG_QUERY_MPD(diff_bias, diff_weights, 1);
2912  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2913  };
2914 
2916  const primitive::at &src, const primitive::at diff_dst,
2917  const memory &diff_weights) {
2918  mkldnn_primitive_t result;
2919  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2920  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
2921  check_num_parameters(aprimitive_desc.get(), 2, 1,
2922  "inner product backward weights");
2924  aprimitive_desc.get(), inputs, outputs),
2925  "could not create a inner product backward weights primitive");
2926  reset(result);
2927  }
2928 
2930  const primitive::at &src, const primitive::at diff_dst,
2931  const memory &diff_weights, const memory &diff_bias) {
2932  mkldnn_primitive_t result;
2933  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2934  const_mkldnn_primitive_t outputs[] =
2935  { diff_weights.get(), diff_bias.get()};
2936  check_num_parameters(aprimitive_desc.get(), 2, 2,
2937  "inner product backward weights");
2939  aprimitive_desc.get(), inputs, outputs),
2940  "could not create a inner product backward weights primitive");
2941  reset(result);
2942  }
2943 };
2944 
2946 
2952 
2953 struct rnn_cell {
2954  struct desc {
2956 
2957  desc(algorithm kind, algorithm activation_f) {
2959  mkldnn::convert_to_c(kind),
2960  mkldnn::convert_to_c(activation_f), 0U, 0, 0),
2961  "could not init an rnn cell descriptor");
2962  }
2964 
2965  operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
2966 
2968  { return algorithm(c_rnn_cell_.cell_kind); }
2970  { return algorithm(c_rnn_cell_.activation_kind); }
2971 
2972  float get_alpha() const { return c_rnn_cell_.alpha; }
2973  void set_alpha(float alpha) {
2974  c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu;
2975  c_rnn_cell_.alpha = alpha;
2976  }
2977 
2978  float get_clipping() const { return c_rnn_cell_.clipping; }
2979  void set_clipping(float clipping) {
2980  c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping;
2981  c_rnn_cell_.clipping = clipping;
2982  }
2983 
2984  int get_gates_count() const {
2985  return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_);
2986  }
2987  int get_state_count() const {
2988  return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_);
2989  }
2990  };
2991 };
2992 
2993 struct rnn_forward : public primitive {
2994  struct desc {
2996  desc(prop_kind aprop_kind, rnn_cell::desc cell,
2997  const rnn_direction direction,
2998  const memory::desc &src_layer_desc,
2999  const memory::desc &src_iter_desc,
3000  const memory::desc &weights_layer_desc,
3001  const memory::desc &weights_iter_desc,
3002  const memory::desc &bias_desc,
3003  const memory::desc &dst_layer_desc,
3004  const memory::desc &dst_iter_desc
3005  ) {
3007  mkldnn::convert_to_c(aprop_kind), cell,
3008  mkldnn::convert_to_c(direction),
3009  &src_layer_desc.data, &src_iter_desc.data,
3010  &weights_layer_desc.data, &weights_iter_desc.data,
3011  &bias_desc.data,
3012  &dst_layer_desc.data, &dst_iter_desc.data),
3013  "could not create an RNN forward descriptor");
3014  }
3015 
3016  };
3017 
3019  primitive_desc(const desc &desc, const engine &e)
3020  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3021 
3022  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3023  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3024 
3025  REG_QUERY_MPD(src_layer, src, 0);
3026  REG_QUERY_MPD(src_iter, src, 1);
3027  REG_QUERY_MPD(weights_layer, weights, 0);
3028  REG_QUERY_MPD(weights_iter, weights, 1);
3029  REG_QUERY_MPD(bias, weights, 2);
3030  REG_QUERY_MPD(dst_layer, dst, 0);
3031  REG_QUERY_MPD(dst_iter, dst, 1);
3032  REG_QUERY_MPD(workspace, workspace, 0);
3033  };
3034 
3035  rnn_forward(const primitive_desc &aprimitive_desc,
3036  const primitive::at &src_layer, const primitive::at &src_iter,
3037  const primitive::at &weights_layer,
3038  const primitive::at &weights_iter, const primitive::at &bias,
3039  const memory &dst_layer, const memory &dst_iter,
3040  const memory &workspace) {
3041  mkldnn_primitive_t result;
3042  mkldnn_primitive_at_t inputs[5];
3043  const_mkldnn_primitive_t outputs[3];
3044  int idx=0;
3045  inputs[idx++] = src_layer.data;
3046  if (!is_null_memory(src_iter.data.primitive))
3047  inputs[idx++] = src_iter.data;
3048  inputs[idx++] = weights_layer.data;
3049  inputs[idx++] = weights_iter.data;
3050  if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data;
3051 
3052  idx=0;
3053  outputs[idx++] = dst_layer.get();
3054  if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get();
3055  if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get();
3056 
3058  aprimitive_desc.get(), inputs, outputs),
3059  "could not create an RNN forward primitive");
3060  reset(result);
3061  }
3062 };
3063 
3064 struct rnn_backward : public primitive {
3065  struct desc {
3067  desc(prop_kind aprop_kind, rnn_cell::desc cell,
3068  const rnn_direction direction,
3069  const memory::desc &src_layer_desc,
3070  const memory::desc &src_iter_desc,
3071  const memory::desc &weights_layer_desc,
3072  const memory::desc &weights_iter_desc,
3073  const memory::desc &bias_desc,
3074  const memory::desc &dst_layer_desc,
3075  const memory::desc &dst_iter_desc,
3076  const memory::desc &diff_src_layer_desc,
3077  const memory::desc &diff_src_iter_desc,
3078  const memory::desc &diff_weights_layer_desc,
3079  const memory::desc &diff_weights_iter_desc,
3080  const memory::desc &diff_bias_desc,
3081  const memory::desc &diff_dst_layer_desc,
3082  const memory::desc &diff_dst_iter_desc) {
3084  mkldnn::convert_to_c(aprop_kind), cell,
3085  mkldnn::convert_to_c(direction),
3086  &src_layer_desc.data, &src_iter_desc.data,
3087  &weights_layer_desc.data, &weights_iter_desc.data,
3088  &bias_desc.data,
3089  &dst_layer_desc.data, &dst_iter_desc.data,
3090  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
3091  &diff_weights_layer_desc.data,
3092  &diff_weights_iter_desc.data, &diff_bias_desc.data,
3093  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
3094  "could not create an RNN backward descriptor");
3095  }
3096 
3097  };
3098 
3100  primitive_desc(const desc &desc, const engine &e,
3101  const rnn_forward::primitive_desc &hint_fwd_pd)
3102  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3103 
3104  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3105  const rnn_forward::primitive_desc &hint_fwd_pd)
3106  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3107 
3108  REG_QUERY_MPD(src_layer, src, 0);
3109  REG_QUERY_MPD(src_iter, src, 1);
3110  REG_QUERY_MPD(weights_layer, weights, 0);
3111  REG_QUERY_MPD(weights_iter, weights, 1);
3112  REG_QUERY_MPD(bias, weights, 2);
3113  REG_QUERY_MPD(dst_layer, dst, 0);
3114  REG_QUERY_MPD(dst_iter, dst, 1);
3115  REG_QUERY_MPD(workspace, workspace, 0);
3116 
3117  REG_QUERY_MPD(diff_src_layer, diff_src, 0);
3118  REG_QUERY_MPD(diff_src_iter, diff_src, 1);
3119  REG_QUERY_MPD(diff_weights_layer, diff_weights, 0);
3120  REG_QUERY_MPD(diff_weights_iter, diff_weights, 1);
3121  REG_QUERY_MPD(diff_bias, diff_weights, 2);
3122  REG_QUERY_MPD(diff_dst_layer, diff_dst, 0);
3123  REG_QUERY_MPD(diff_dst_iter, diff_dst, 1);
3124  };
3125 
3126  // With last iteration (with and without input src_iter)
3127  rnn_backward(const primitive_desc &aprimitive_desc,
3128  const primitive::at &src_layer,
3129  const primitive::at &src_iter,
3130  const primitive::at &weights_layer,
3131  const primitive::at &weights_iter,
3132  const primitive::at &bias,
3133  const primitive::at &dst_layer,
3134  const primitive::at &dst_iter,
3135  const memory &diff_src_layer,
3136  const memory &diff_src_iter,
3137  const memory &diff_weights_layer,
3138  const memory &diff_weights_iter,
3139  const memory &diff_bias,
3140  const primitive::at &diff_dst_layer,
3141  const primitive::at &diff_dst_iter,
3142  const primitive::at &workspace) {
3143  mkldnn_primitive_t result;
3144  mkldnn_primitive_at_t inputs[10];
3145  const_mkldnn_primitive_t outputs[5];
3146  int idx=0;
3147  inputs[idx++] = src_layer.data;
3148  if (!is_null_memory(src_iter.data.primitive))
3149  inputs[idx++] = src_iter.data;
3150  inputs[idx++] = weights_layer.data;
3151  inputs[idx++] = weights_iter.data;
3152  if (!is_null_memory(bias.data.primitive))
3153  inputs[idx++] = bias.data;
3154  inputs[idx++] = dst_layer.data;
3155  if (!is_null_memory(dst_iter.data.primitive))
3156  inputs[idx++] = dst_iter.data;
3157  inputs[idx++] = diff_dst_layer.data;
3158  if (!is_null_memory(diff_dst_iter.data.primitive))
3159  inputs[idx++] = diff_dst_iter.data;
3160  inputs[idx++] = workspace.data;
3161 
3162  idx = 0;
3163  outputs[idx++] = diff_src_layer.get();
3164  if (!is_null_memory(diff_src_iter.get()))
3165  outputs[idx++] = diff_src_iter.get();
3166  outputs[idx++] = diff_weights_layer.get();
3167  outputs[idx++] = diff_weights_iter.get();
3168  if (!is_null_memory(diff_bias.get())) outputs[idx++] = diff_bias.get();
3170  aprimitive_desc.get(), inputs, outputs),
3171  "could not create an RNN backward primitive");
3172  reset(result);
3173  }
3174 };
3175 
3177 
3183 
3184 struct shuffle_forward : public primitive {
3185  struct desc {
3187  desc(prop_kind aprop_kind, const memory::desc &data_desc,
3188  int axis, int group_size) {
3190  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
3191  axis, group_size),
3192  "could not create a shuffle forward descriptor");
3193  }
3194  };
3195 
3197  primitive_desc(const desc &desc, const engine &e)
3198  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3199 
3200  REG_QUERY_MPD(src, src, 0);
3201  REG_QUERY_MPD(dst, dst, 0);
3202  };
3203 
3204  shuffle_forward(const primitive_desc &aprimitive_desc,
3205  const primitive::at &src, const memory &dst) {
3206  mkldnn_primitive_t result;
3207  mkldnn_primitive_at_t inputs[] = { src.data };
3208  const_mkldnn_primitive_t outputs[] = { dst.get() };
3209  check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle forward");
3211  aprimitive_desc.get(), inputs, outputs),
3212  "could not create a shuffle forward primitive");
3213  reset(result);
3214  }
3215 };
3216 
3217 struct shuffle_backward : public primitive {
3218  struct desc {
3220  desc(const memory::desc &diff_data_desc, int axis, int group_size) {
3222  &diff_data_desc.data, axis, group_size),
3223  "could not create a shuffle backward descriptor");
3224  }
3225  };
3226 
3228  primitive_desc(const desc &desc, const engine &e,
3229  const shuffle_forward::primitive_desc &hint_fwd_pd)
3230  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3231 
3232  REG_QUERY_MPD(diff_src, diff_src, 0);
3233  REG_QUERY_MPD(diff_dst, diff_dst, 0);
3234  };
3235 
3236  shuffle_backward(const primitive_desc &aprimitive_desc,
3237  const primitive::at &diff_dst, const memory &diff_src) {
3238  mkldnn_primitive_t result;
3239  mkldnn_primitive_at_t inputs[] = { diff_dst.data};
3240  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3241  check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle backward");
3243  aprimitive_desc.get(), inputs, outputs),
3244  "could not create a shuffle backward primitive");
3245  reset(result);
3246  }
3247 };
3248 
3250 
3252 
3258 
3259 #ifndef DOXYGEN_SHOULD_SKIP_THIS
3260 template <> struct handle_traits<mkldnn_stream_t> {
3261  static constexpr auto destructor = &mkldnn_stream_destroy;
3262 };
3263 #endif
3264 
3265 struct stream: public handle<mkldnn_stream_t> {
3266  using handle::handle;
3267 
3271 
3273  return static_cast<mkldnn_stream_kind_t>(akind);
3274  }
3276  stream(kind akind) {
3277  mkldnn_stream_t astream;
3279  convert_to_c(akind)),
3280  "could not create a stream");
3281  reset(astream);
3282  }
3283 
3288  stream &submit(std::vector<primitive> primitives) {
3289  // TODO: find a proper way to convert vector<primitive> to
3290  // vector<mkldnn_primitive_t>
3291  if (primitives.size() == 0) return *this;
3292  std::vector<mkldnn_primitive_t> c_api_primitives;
3293  c_api_primitives.reserve(primitives.size());
3294  auto convert_to_c = [](primitive p) { return p.get(); };
3295  std::transform(primitives.begin(), primitives.end(),
3296  std::back_inserter(c_api_primitives), convert_to_c);
3297 
3298  mkldnn_primitive_t c_api_error_primitive;
3300  mkldnn_stream_submit(get(),
3301  c_api_primitives.size(), &c_api_primitives[0],
3302  &c_api_error_primitive),
3303  "could not submit primitives to a stream",
3304  &c_api_error_primitive);
3305 
3306  return *this;
3307  }
3308 
3315  bool wait(bool block = true) {
3316  mkldnn_primitive_t c_api_error_primitive;
3317  mkldnn_status_t status = mkldnn_stream_wait(get(),
3318  block, &c_api_error_primitive);
3319  if (status != mkldnn_success
3320  && status != mkldnn_try_again)
3321  error::wrap_c_api(status, "could not wait on a stream",
3322  &c_api_error_primitive);
3323  return (status == mkldnn_success);
3324  }
3325 
3327  mkldnn_primitive_t c_api_error_primitive;
3329  mkldnn_stream_rerun(get(), &c_api_error_primitive),
3330  "could not rerun a stream", &c_api_error_primitive);
3331  return *this;
3332  }
3333 };
3334 
3335 #undef REG_QUERY_MPD
3336 
3338 
3340 
3341 } // namespace mkldnn
3342 
3343 #endif
void append_sum(float scale=1.)
Definition: mkldnn.hpp:385
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2379
Definition: mkldnn.hpp:2330
LRN within a single channel.
Definition: mkldnn_types.h:484
primitive error_primitive
Definition: mkldnn.hpp:164
A descriptor of a Local Response Normalization (LRN) operation.
Definition: mkldnn_types.h:822
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1478
Definition: mkldnn.hpp:342
blocked weights format
Definition: mkldnn_types.h:306
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const memory &dst)
Definition: mkldnn.hpp:2818
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2171
Definition: mkldnn.hpp:269
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1060
blocked weights format
Definition: mkldnn_types.h:309
op descriptor
Definition: mkldnn_types.h:1164
primitive_desc(const memory::desc &output, int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1070
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1621
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
blocked weights format with additional buffer with size equal to the number of output channels multip...
Definition: mkldnn_types.h:333
Definition: mkldnn.hpp:3064
blocked weights format
Definition: mkldnn_types.h:293
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy(mkldnn_primitive_attr_t attr)
Deletes an attr.
blocked weights format
Definition: mkldnn_types.h:355
mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create(mkldnn_primitive_desc_t *sum_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, const float *scales, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place sum_primitive_desc for sum of n inputs multiplied by scale with resulting output...
Definition: mkldnn.hpp:257
A Softmax primitive.
Definition: mkldnn_types.h:428
number of outputs expected
Definition: mkldnn_types.h:1153
bool operator!=(const handle &other) const
Definition: mkldnn.hpp:88
mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream)
Destroys an execution stream.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:3022
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:1631
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2491
stream & submit(std::vector< primitive > primitives)
Submits a vector of primitives to a stream for computations.
Definition: mkldnn.hpp:3288
bool operator==(const primitive_desc &other) const
Definition: mkldnn.hpp:778
A base class for all primitive descriptors.
Definition: mkldnn.hpp:1227
Definition: mkldnn.hpp:2204
mkldnn_status_t
Status values returned by Intel(R) MKL-DNN functions.
Definition: mkldnn_types.h:39
stream & rerun()
Definition: mkldnn.hpp:3326
Definition: mkldnn.hpp:2167
A descriptor of a convolution operation.
Definition: mkldnn_types.h:675
Definition: mkldnn.hpp:300
desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis, int group_size)
Definition: mkldnn.hpp:3187
Definition: mkldnn.hpp:2142
The operation failed and should be retried.
Definition: mkldnn_types.h:45
memory null_memory(engine eng)
Definition: mkldnn.hpp:874
mkldnn_status_t MKLDNN_API mkldnn_memory_primitive_desc_create(mkldnn_primitive_desc_t *memory_primitive_desc, const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine)
Creates a memory_primitive_desc memory primitive descriptor using memory_desc and engine...
blocked weights format
Definition: mkldnn_types.h:265
mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops)
Creates an empty sequence of post operations post_ops.
Definition: mkldnn.hpp:329
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy(mkldnn_primitive_desc_t primitive_desc)
Deletes a primitive_desc.
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1571
mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create(mkldnn_primitive_desc_t *concat_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, int concat_dimension, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place concat_primitive_desc for concatenation of n inputs by concat_dimension with res...
4D RNN bias tensor in the format (num_layers, num_directions, num_gates, output_channels).
Definition: mkldnn_types.h:245
4D data tensor with the physical layout chwn, used in Neon.
Definition: mkldnn_types.h:163
Definition: mkldnn.hpp:265
padding_kind
Definition: mkldnn.hpp:232
The operation failed because of incorrect function arguments.
Definition: mkldnn_types.h:47
Forward data propagation (alias for mkldnn_forward_inference)
Definition: mkldnn_types.h:389
Definition: mkldnn.hpp:2005
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1533
Backward data propagation.
Definition: mkldnn_types.h:395
Definition: mkldnn.hpp:2403
static void validate_dims(std::vector< T > v)
Definition: mkldnn.hpp:586
Definition: mkldnn.hpp:3227
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_get_attr(const_mkldnn_primitive_desc_t primitive_desc, const_mkldnn_primitive_attr_t *attr)
Returns a constant reference to the attribute of a primitive_desc.
Definition: mkldnn.hpp:3217
mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init(mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, mkldnn_data_type_t data_type, mkldnn_memory_format_t format)
Initializes a memory_desc memory descriptor using ndims, dims, data_type, and data format...
desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2369
Definition: mkldnn.hpp:274
blocked weights format
Definition: mkldnn_types.h:289
Undefined memory format, used for empty memory descriptors.
Definition: mkldnn_types.h:137
const_mkldnn_primitive_desc_t get_primitive_desc() const
Returns the descriptor of the underlying C API primitive.
Definition: mkldnn.hpp:210
concat(const primitive_desc &concat_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1111
memory::desc desc()
Returns the memory descriptor (memory::desc) of this memory primitive descriptor.
Definition: mkldnn.hpp:768
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:1966
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
float alpha
alpha is a negative slope parameter (used only if (flags & mkldnn_rnn_cell_with_relu) != 0) ...
Definition: mkldnn_types.h:926
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone(mkldnn_primitive_attr_t *attr, const_mkldnn_primitive_attr_t existing_attr)
Makes a copy of an existing_attr.
#define TENSOR_MAX_DIMS
Maximum number of dimensions a tensor can have.
Definition: mkldnn_types.h:549
format
Memory format specification. See mkldnn_memory_format_t for a detailed description.
Definition: mkldnn.hpp:605
Definition: mkldnn.hpp:290
4D weights tensor with physical layout oihw, used in Caffe.
Definition: mkldnn_types.h:184
A descriptor of a Softmax operation.
Definition: mkldnn_types.h:772
blocked weights format
Definition: mkldnn_types.h:358
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone(mkldnn_primitive_desc_t *primitive_desc, const_mkldnn_primitive_desc_t existing_primitive_desc)
Makes a copy of a primitive_desc.
softmax_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2389
blocked weights format
Definition: mkldnn_types.h:359
blocked data format
Definition: mkldnn_types.h:252
mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle(const_mkldnn_primitive_t memory, void **handle)
For a memory primitive, returns the data handle.
Definition: mkldnn.hpp:244
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to data using al...
A descriptor of an inner product operation.
Definition: mkldnn_types.h:880
mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops)
Deletes a post_ops sequence.
std::vector< std::remove_extent< mkldnn_dims_t >::type > dims
Definition: mkldnn.hpp:584
3D RNN data tensor in the format (seq_length, batch, input channels).
Definition: mkldnn_types.h:221
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3197
An opaque structure for a chain of post operations.
An opaque structure to describe a primitive descriptor.
batch normalization descriptor
Definition: mkldnn_types.h:1173
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1690
mkldnn_rnn_direction_t
A direction of RNN primitive execution.
Definition: mkldnn_types.h:933
void reset(T t, bool weak=false)
Resets the value of a C handle.
Definition: mkldnn.hpp:79
A convolution primitive.
Definition: mkldnn_types.h:422
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1838
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2068
mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle(mkldnn_primitive_t memory, void *handle)
For a memory primitive, sets the data handle.
engine(const mkldnn_engine_t &aengine)
Definition: mkldnn.hpp:538
engine(const handle< mkldnn_primitive_desc_t > &pd)
Definition: mkldnn.hpp:541
engine get_engine()
Definition: mkldnn.hpp:1240
desc(dims adims, data_type adata_type, format aformat)
Constructs a memory descriptor.
Definition: mkldnn.hpp:734
blocked data format
Definition: mkldnn_types.h:253
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for forward propagation using prop_kind...
Definition: mkldnn.hpp:225
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2767
sum(const primitive_desc &sum_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1200
An execution engine.
Definition: mkldnn.hpp:503
memory(const primitive_desc &adesc, void *ahandle)
Definition: mkldnn.hpp:824
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2835
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise(mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg, float alpha, float beta)
Appends eltwise post operation to the post_ops with given parameters kind, alpha and beta (...
static void wrap_c_api(mkldnn_status_t status, const std::string &message, mkldnn_primitive_t *error_primitive=0)
A convenience function for wrapping calls to the C API. Checks the return status and throws an error ...
Definition: mkldnn.hpp:188
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2206
Undefined primitive (XXX: why do we have it?).
Definition: mkldnn_types.h:406
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to data using ...
An inner product primitive.
Definition: mkldnn_types.h:436
Packed weights format used in RNN.
Definition: mkldnn_types.h:363
void check_num_parameters(const const_mkldnn_primitive_desc_t &aprimitive_desc, int n_inputs, int n_outputs, const std::string &prim_name)
Definition: mkldnn.hpp:879
Round down.
Definition: mkldnn_types.h:82
4D grouped weights tensor with the physical layout goiw.
Definition: mkldnn_types.h:202
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2418
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1708
Definition: mkldnn.hpp:264
round_mode get_int_output_round_mode() const
Definition: mkldnn.hpp:426
primitive_attr()
Definition: mkldnn.hpp:419
Definition: mkldnn_types.h:480
Definition: mkldnn.hpp:2315
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_weights_qparams(mkldnn_primitive_attr_t attr, int count, int mask, const float *weights_scales)
Sets quantization scales weights_scales for RNN weights tensors.
mkldnn_primitive_at_t MKLDNN_API mkldnn_primitive_at(const_mkldnn_primitive_t primitive, size_t output_index)
Creates an mkldnn_primitive_at_t structure from a primitive and output_index.
primitive_desc(const desc &desc, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2414
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2404
Definition: mkldnn.hpp:2378
void get_params_sum(int index, float &scale) const
Definition: mkldnn.hpp:390
Definition: mkldnn.hpp:247
32-bit signed integer.
Definition: mkldnn_types.h:68
primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2848
Max pooling.
Definition: mkldnn_types.h:475
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1392
memory::desc zero_md()
Definition: mkldnn.hpp:868
Definition: mkldnn.hpp:336
primitive_desc(const memory::primitive_desc &input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:1003
mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init(mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for forward propagation using prop_kind (possible values are mkldnn_forward...
blocked weights format
Definition: mkldnn_types.h:279
const post_ops get_post_ops() const
Definition: mkldnn.hpp:460
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims kernel, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2144
execution engine
Definition: mkldnn_types.h:1149
stream(kind akind)
Constructs a stream.
Definition: mkldnn.hpp:3276
Definition: mkldnn.hpp:1002
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next(mkldnn_primitive_desc_iterator_t iterator)
Iterates over primitive descriptors.
Definition: mkldnn.hpp:335
desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2836
mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for backward propagation using alg_kind, memory descriptors, and pooling parameters in spatial domain: strides, kernel sizes, padding_l, padding_r, and padding_kind.
Definition: mkldnn.hpp:2141
blocked weights format
Definition: mkldnn_types.h:286
static mkldnn_memory_format_t convert_to_c(format aformat)
Definition: mkldnn.hpp:863
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2335
Definition: mkldnn.hpp:320
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create(mkldnn_primitive_attr_t *attr)
Creates an empty (default) attr attribute.
Definition: mkldnn_types.h:911
mkldnn_status_t MKLDNN_API mkldnn_stream_submit(mkldnn_stream_t stream, size_t n, mkldnn_primitive_t primitives[], mkldnn_primitive_t *error_primitive)
Submits primitives to an execution stream.
algorithm
Definition: mkldnn.hpp:255
input memory primitive desc
Definition: mkldnn_types.h:1179
blocked weights format
Definition: mkldnn_types.h:300
mkldnn_shuffle_desc_t data
Definition: mkldnn.hpp:3186
5D grouped weights tensor with the physical layout goihw, used in Caffe.
Definition: mkldnn_types.h:206
const_mkldnn_primitive_t primitive
Primitive to specify the output for.
Definition: mkldnn_types.h:1109
Definition: mkldnn.hpp:289
rnn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const memory &dst_layer, const memory &dst_iter, const memory &workspace)
Definition: mkldnn.hpp:3035
mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init(mkldnn_rnn_cell_desc_t *rnn_cell_desc, mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f, unsigned int flags, float alpha, float clipping)
Initializes a recurrent cell descriptor rnn_cell_desc using kind (possible values are ...
A descriptor of an element-wise operation.
Definition: mkldnn_types.h:737
rnn descriptor
Definition: mkldnn_types.h:1175
memory::primitive_desc variance_primitive_desc() const
Definition: mkldnn.hpp:2477
An element-wise primitive.
Definition: mkldnn_types.h:426
Definition: mkldnn.hpp:2402
destination grad.
Definition: mkldnn_types.h:1186
algorithm get_cell_kind() const
Definition: mkldnn.hpp:2967
engine get_engine()
Definition: mkldnn.hpp:1197
Definition: mkldnn.hpp:2316
mkldnn_status_t MKLDNN_API mkldnn_stream_wait(mkldnn_stream_t stream, int block, mkldnn_primitive_t *error_primitive)
Waits for all primitives in the execution stream to finish.
mkldnn_alg_kind_t activation_kind
Activation function used.
Definition: mkldnn_types.h:921
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1184
blocked weights format
Definition: mkldnn_types.h:303
A descriptor for an rnn operation.
Definition: mkldnn_types.h:948
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1369
Definition: mkldnn.hpp:1058
Definition: mkldnn.hpp:277
Definition: mkldnn.hpp:259
eltwise descriptor
Definition: mkldnn_types.h:1169
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2586
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1417
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_data_qparams(mkldnn_primitive_attr_t attr, const float scale, const float shift)
Sets quantization scale and shift for RNN data tensors.
Definition: mkldnn.hpp:276
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights_or_workspace, const memory &diff_src)
Definition: mkldnn.hpp:2723
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2053
size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind)
Returns the number of engines of a particular kind.
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2879
batch_normalization_flag
Definition: mkldnn.hpp:288
A memory primitive.
Definition: mkldnn_types.h:408
float clipping
clipping parameter (used only if (flags & mkldnn_rnn_cell_with_clipping) != 0)
Definition: mkldnn_types.h:929
blocked weights format
Definition: mkldnn_types.h:288
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc)
Definition: mkldnn.hpp:3067
Eltwise: soft_relu.
Definition: mkldnn_types.h:471
void set_post_ops(post_ops ops)
Definition: mkldnn.hpp:469
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:2803
Definition: mkldnn.hpp:341
Definition: mkldnn.hpp:261
mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind(const_mkldnn_post_ops_t post_ops, int index)
Returns the type of post operation with index index in given post_ops.
RNN cell.
Definition: mkldnn_types.h:486
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2168
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1729
bool is_null_memory(const const_mkldnn_primitive_t &aprimitive)
Definition: mkldnn.hpp:899
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2852
Definition: mkldnn.hpp:367
blocked weights format
Definition: mkldnn_types.h:315
bool operator==(const handle &other) const
Definition: mkldnn.hpp:87
Definition: mkldnn.hpp:1329
Backward weights propagation.
Definition: mkldnn_types.h:397
void set_int_output_round_mode(round_mode mode)
Definition: mkldnn.hpp:433
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:2995
blocked weights format
Definition: mkldnn_types.h:354
32-bit/single-precision floating point.
Definition: mkldnn_types.h:66
blocked weights format
Definition: mkldnn_types.h:262
blocked data format
Definition: mkldnn_types.h:251
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1553
algorithm get_activation() const
Definition: mkldnn.hpp:2969
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2179
2D weights tensor with physical layout oi.
Definition: mkldnn_types.h:172
Just a sentinel, not real memory format.
Definition: mkldnn_types.h:367
Memory descriptor.
Definition: mkldnn_types.h:634
Definition: mkldnn.hpp:2766
Definition: mkldnn.hpp:303
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to data using m...
Base class for all computational primitives.
Definition: mkldnn.hpp:106
shuffle_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:3204
mkldnn_batch_normalization_flag_t
Flags for the batch-normalization primitive.
Definition: mkldnn_types.h:503
void set_clipping(float clipping)
Definition: mkldnn.hpp:2979
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:1645
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2006
Definition: mkldnn.hpp:2765
desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2453
Definition: mkldnn.hpp:280
pooling descriptor
Definition: mkldnn_types.h:1171
Definition: mkldnn.hpp:2205
const mkldnn_memory_desc_t MKLDNN_API * mkldnn_primitive_desc_query_memory_d(const_mkldnn_primitive_desc_t primitive_desc)
Queries primitive descriptor for memory descriptor.
prop_kind
Definition: mkldnn.hpp:240
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2143
Definition: mkldnn.hpp:267
blocked weights format
Definition: mkldnn_types.h:261
3D weights tensor with physical layout wio.
Definition: mkldnn_types.h:181
blocked weights format
Definition: mkldnn_types.h:314
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for forward propagation using prop_kind (p...
unsigned int flags
RNN cell flags.
Definition: mkldnn_types.h:923
3D data tensor with the physical layout ncw.
Definition: mkldnn_types.h:151
blocked weights format
Definition: mkldnn_types.h:291
convolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1515
The operation was successful.
Definition: mkldnn_types.h:41
blocked weights format with additional buffer with size equal to the number of groups and containing ...
Definition: mkldnn_types.h:348
mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine, mkldnn_engine_kind_t kind, size_t index)
Creates an engine of particular kind and index.
blocked weights format
Definition: mkldnn_types.h:326
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2905
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1617
desc(algorithm kind, algorithm activation_f)
Definition: mkldnn.hpp:2957
blocked weights format
Definition: mkldnn_types.h:334
Definition: mkldnn.hpp:326
Definition: mkldnn.hpp:245
primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr, const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd)
Definition: mkldnn.hpp:1228
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_int_output_round_mode(const_mkldnn_primitive_attr_t attr, mkldnn_round_mode_t *round_mode)
Returns integer output rounding mode round_mode for a given attr, previously set by mkldnn_primitive_...
blocked weights format
Definition: mkldnn_types.h:352
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:3066
Backward propagation (with respect to all parameters).
Definition: mkldnn_types.h:393
5D data tensor with the physical layout ndhwc, used in TensorFlow.
Definition: mkldnn_types.h:169
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:2929
softmax descriptor
Definition: mkldnn_types.h:1170
mkldnn_round_mode_t
Rounding mode.
Definition: mkldnn_types.h:78
A deconvolution primitive.
Definition: mkldnn_types.h:424
Definition: mkldnn.hpp:330
Definition: mkldnn.hpp:275
primitive_desc(const desc &adesc, const engine &aengine)
Constructs a memory primitive descriptor.
Definition: mkldnn.hpp:758
Use global statistics.
Definition: mkldnn_types.h:516
Definition: mkldnn.hpp:31
primitive_desc(int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1083
blocked weights format
Definition: mkldnn_types.h:292
no query
Definition: mkldnn_types.h:1147
Definition: mkldnn.hpp:1669
blocked weights format
Definition: mkldnn_types.h:341
blocked weights format
Definition: mkldnn_types.h:304
mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for forward propagation using prop_kind (possible valu...
mkldnn_status_t MKLDNN_API mkldnn_view_primitive_desc_create(mkldnn_primitive_desc_t *view_primitive_desc, const_mkldnn_primitive_desc_t memory_primitive_desc, const mkldnn_dims_t dims, const mkldnn_dims_t offsets)
Creates a view_primitive_desc for a given memory_primitive_desc, with dims sizes and offset offsets...
8-bit unsigned integer.
Definition: mkldnn_types.h:74
Definition: mkldnn.hpp:346
Average pooling include padding.
Definition: mkldnn_types.h:477
Unspecified format.
Definition: mkldnn_types.h:140
inner_product_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at weights, const memory &diff_src)
Definition: mkldnn.hpp:2861
Definition: mkldnn.hpp:2027
destination memory primitive desc
Definition: mkldnn_types.h:1185
memory::primitive_desc mean_primitive_desc() const
Definition: mkldnn.hpp:2475
5D RNN weights tensor in the format (num_layers, num_directions, input_channels, num_gates, output_channels).
Definition: mkldnn_types.h:231
GRU cell with linear before reset.
Definition: mkldnn_types.h:499
memory(const primitive_desc &adesc)
Constructs a memory primitive.
Definition: mkldnn.hpp:797
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2105
mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init(mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int axis, int group_size)
Initializes a shuffle_desc for forward propagation using prop_kind, memory descriptor data_desc...
Local response normalization (LRN) across multiple channels.
Definition: mkldnn_types.h:482
blocked weights format
Definition: mkldnn_types.h:276
GRU cell.
Definition: mkldnn_types.h:490
Eager stream.
Definition: mkldnn_types.h:1200
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output, const primitive_attr &aattr)
Definition: mkldnn.hpp:953
void set_output_scales(int mask, const std::vector< float > &scales)
Definition: mkldnn.hpp:453
at(const primitive &aprimitive, size_t at=0)
Constructs a wrapper specifying aprimitive output with index at.
Definition: mkldnn.hpp:143
implementation name
Definition: mkldnn_types.h:1160
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1907
Definition: mkldnn.hpp:1330
desc(const memory::desc &diff_data_desc, int axis, int group_size)
Definition: mkldnn.hpp:3220
Definition: mkldnn.hpp:3218
Definition: mkldnn.hpp:256
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2243
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales(const_mkldnn_primitive_attr_t attr, int *count, int *mask, const float **scales)
Returns count, correspondence scale mask, and pointer to a constant floating point array of output sc...
3D weights tensor with physical layout oiw.
Definition: mkldnn_types.h:178
Eltwise: parametric exponential linear unit (elu)
Definition: mkldnn_types.h:459
kind
Kinds of engines.
Definition: mkldnn.hpp:508
Definition: mkldnn.hpp:2067
Definition: mkldnn.hpp:2833
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2382
Intel(R) MKL-DNN exception class.
Definition: mkldnn.hpp:161
round_mode
Definition: mkldnn.hpp:223
bool operator==(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:908
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1796
Eltwise: ReLU.
Definition: mkldnn_types.h:455
Definition: mkldnn.hpp:2366
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1331
Definition: mkldnn.hpp:233
1D data tensor.
Definition: mkldnn_types.h:146
mkldnn_primitive_at_t data
The underlying C API structure.
Definition: mkldnn.hpp:136
memory::primitive_desc query_mpd(query what, int idx=0) const
Queries and returns requested memory primitive descriptor.
Definition: mkldnn.hpp:1281
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2665
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops(mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops)
Sets configured post_ops to an attribute attr for future use (when primitive descriptor is being crea...
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3104
primitive_desc(const desc &desc, const engine &e, const shuffle_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3228
4D weights tensor with physical layout ihwo.
Definition: mkldnn_types.h:190
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2317
mkldnn_memory_format_t
Memory format specification.
Definition: mkldnn_types.h:135
Definition: mkldnn.hpp:1001
Eltwise: square.
Definition: mkldnn_types.h:461
Definition: mkldnn.hpp:1135
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1351
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1013
Definition: mkldnn.hpp:281
mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for forward propagation using prop_kind (possible values are mkldnn_forwar...
int MKLDNN_API mkldnn_memory_primitive_desc_equal(const_mkldnn_primitive_desc_t lhs, const_mkldnn_primitive_desc_t rhs)
Compares two descriptors of memory primitives.
void set_rnn_data_qparams(const float scale, const float shift)
Definition: mkldnn.hpp:474
static mkldnn_data_type_t convert_to_c(data_type adata_type)
Definition: mkldnn.hpp:860
4D data tensor with the physical layout nhwc, used in TensorFlow.
Definition: mkldnn_types.h:160
void set_data_handle(void *handle) const
Definition: mkldnn.hpp:854
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2560
Definition: mkldnn.hpp:268
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2069
Backward bias propagation.
Definition: mkldnn_types.h:399
Definition: mkldnn.hpp:942
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2016
blocked weights format
Definition: mkldnn_types.h:349
Use scale and shift parameters.
Definition: mkldnn_types.h:529
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1671
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor deconv_desc for forward propagation using prop_kind (possible ...
query
Definition: mkldnn.hpp:311
Definition: mkldnn.hpp:279
weights format with additional buffer size equal to the number of output channels multiplied by numbe...
Definition: mkldnn_types.h:325
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index, void *result)
Queries primitive descriptor.
float get_alpha() const
Definition: mkldnn.hpp:2972
blocked weights format
Definition: mkldnn_types.h:275
blocked weights format
Definition: mkldnn_types.h:335
A descriptor of a shuffle operation.
Definition: mkldnn_types.h:720
void get_params_eltwise(int index, float &scale, algorithm &alg, float &alpha, float &beta) const
Definition: mkldnn.hpp:402
Definition: mkldnn_types.h:943
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to wei...
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2279
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1952
Definition: mkldnn.hpp:418
blocked weights format
Definition: mkldnn_types.h:343
blocked weights format
Definition: mkldnn_types.h:311
int get_gates_count() const
Definition: mkldnn.hpp:2984
int ndims
Number of dimensions.
Definition: mkldnn_types.h:639
reorder(const primitive_desc &aprimitive_desc, const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:966
Definition: mkldnn.hpp:2004
Definition: mkldnn.hpp:1059
kind
A proxy to C primitive kind enum.
Definition: mkldnn.hpp:113
5D grouped weights tensor with the physical layout giohw.
Definition: mkldnn_types.h:213
void set_alpha(float alpha)
Definition: mkldnn.hpp:2973
mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for backward propagation using alg_kind algorithm memory descriptors diff_...
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2079
5D data tensor with the physical layout ncdhw.
Definition: mkldnn_types.h:166
Definition: mkldnn.hpp:3185
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy(mkldnn_primitive_desc_iterator_t iterator)
Deletes a primitive descriptor iterator.
5D RNN states tensor in the format (num_layers, num_directions, num_states, batch, state channels).
Definition: mkldnn_types.h:224
Definition: mkldnn.hpp:2091
size_t get_size() const
Returns the number of bytes required to allocate the memory described including the padding area...
Definition: mkldnn.hpp:774
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum(mkldnn_post_ops_t post_ops, float scale)
Appends accumulation (sum) post operation to the post_ops.
Definition: mkldnn.hpp:1530
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1764
An RNN primitive.
Definition: mkldnn_types.h:438
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_output(const_mkldnn_primitive_t primitive, size_t index, const_mkldnn_primitive_t *output)
For a primitive, returns output at the index position.
blocked weights format
Definition: mkldnn_types.h:299
mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init(mkldnn_shuffle_desc_t *shuffle_desc, const mkldnn_memory_desc_t *diff_data_desc, int axis, int group_size)
Initializes a shuffle_desc for backward propagation using memory descriptor diff_data_desc, axis, and group size group_size.
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1868
Definition: mkldnn.hpp:2954
eltwise_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2344
mkldnn_prop_kind_t
Kinds of propagation.
Definition: mkldnn_types.h:377
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn.hpp:134
CPU engine.
Definition: mkldnn_types.h:999
Definition: mkldnn.hpp:291
desc(algorithm alg_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2320
Eltwise: square root.
Definition: mkldnn_types.h:465
blocked weights format
Definition: mkldnn_types.h:263
mkldnn_stream_kind_t
Kinds of streams.
Definition: mkldnn_types.h:1196
Definition: mkldnn.hpp:271
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_int_output_round_mode(mkldnn_primitive_attr_t attr, mkldnn_round_mode_t round_mode)
Sets output rounding mode round_mode for integer operations for a given attr.
4D weights tensor with physical layout hwio, used in TensorFlow.
Definition: mkldnn_types.h:187
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn_types.h:1107
Winograd convolution.
Definition: mkldnn_types.h:447
Definition: mkldnn.hpp:246
Definition: mkldnn.hpp:343
Eltwise: linear.
Definition: mkldnn_types.h:467
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1797
mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init(mkldnn_softmax_desc_t *softmax_desc, const mkldnn_memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for backward propagation using memory descriptors diff_desc and data_desc...
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1869
reorder(const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:977
Eltwise: logistic.
Definition: mkldnn_types.h:473
Definition: mkldnn.hpp:2645
Direct convolution.
Definition: mkldnn_types.h:445
Primitive iterator passed over last primitive descriptor.
Definition: mkldnn_types.h:54
Definition: mkldnn.hpp:338
Definition: mkldnn.hpp:270
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &workspace, const memory &dst)
Definition: mkldnn.hpp:2039
source gradient memory primitive desc
Definition: mkldnn_types.h:1182
mkldnn_alg_kind_t cell_kind
RNN cell kind.
Definition: mkldnn_types.h:918
Definition: mkldnn.hpp:1458
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2647
Definition: mkldnn_types.h:935
An opaque structure for primitive descriptor attributes.
Definition: mkldnn.hpp:312
blocked data format
Definition: mkldnn_types.h:255
mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for forward propagation using prop_kind (possible values a...
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2007
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2617
mkldnn_rnn_cell_desc_t c_rnn_cell_
Definition: mkldnn.hpp:2955
bool operator!=(const primitive_desc &other) const
Definition: mkldnn.hpp:783
runtime estimation (seconds)
Definition: mkldnn_types.h:1155
blocked weights format
Definition: mkldnn_types.h:342
bool operator==(const T other) const
Definition: mkldnn.hpp:61
A (in-place) concat primitive.
Definition: mkldnn_types.h:418
mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream, mkldnn_stream_kind_t stream_kind)
Creates an execution stream of stream_kind.
primitive_desc get_primitive_desc() const
Returns the descriptor of the memory primitive.
Definition: mkldnn.hpp:834
blocked weights format
Definition: mkldnn_types.h:277
LSTM cell.
Definition: mkldnn_types.h:488
blocked weights format
Definition: mkldnn_types.h:266
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for backward propagation with respect to data ...
Definition: mkldnn_types.h:944
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2464
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2791
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2794
Undefined data type, used for empty memory descriptors.
Definition: mkldnn_types.h:64
Definition: mkldnn.hpp:1794
16-bit signed integer.
Definition: mkldnn_types.h:70
Definition: mkldnn.hpp:2278
A shuffle primitive.
Definition: mkldnn_types.h:414
blocked weights format with additional buffer with size equal to the number of output channels and co...
Definition: mkldnn_types.h:284
mkldnn_shuffle_desc_t data
Definition: mkldnn.hpp:3219
primitive_desc()
Definition: mkldnn.hpp:755
int len() const
Definition: mkldnn.hpp:375
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc(const_mkldnn_primitive_t primitive, const_mkldnn_primitive_desc_t *primitive_desc)
Retrieves a reference to the primitive_desc descriptor of given primitive.
primitive_desc(const memory::desc &output, const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1147
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2779
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise(const_mkldnn_post_ops_t post_ops, int index, float *scale, mkldnn_alg_kind_t *alg, float *alpha, float *beta)
Gets the eltwise parameters of the post operation with index index in the sequence of post_ops...
Definition: mkldnn.hpp:242
blocked weights format
Definition: mkldnn_types.h:305
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum(const_mkldnn_post_ops_t post_ops, int index, float *scale)
Gets the parameters of the accumulation (sum) post operation with index index in the sequence of post...
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1459
blocked weights format
Definition: mkldnn_types.h:298
A (out-of-place) concat primitive.
Definition: mkldnn_types.h:416
blocked weights format
Definition: mkldnn_types.h:312
Fuse with ReLU.
Definition: mkldnn_types.h:538
Definition: mkldnn.hpp:260
Definition: mkldnn.hpp:278
static size_t get_count(kind akind)
Returns the number of engines of a certain kind.
Definition: mkldnn.hpp:519
mkldnn_query_t
Primitive descriptor query specification.
Definition: mkldnn_types.h:1146
A descriptor of a Batch Normalization operation.
Definition: mkldnn_types.h:849
static engine query(const primitive_desc &pd)
Definition: mkldnn.hpp:551
Definition: mkldnn.hpp:2993
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:1980
blocked data format
Definition: mkldnn_types.h:254
A sum primitive.
Definition: mkldnn_types.h:420
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2740
Definition: mkldnn.hpp:302
blocked weights format
Definition: mkldnn_types.h:339
eltwise_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2302
unsigned flags
Definition: mkldnn_types.h:876
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create_v2(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output, const_mkldnn_primitive_attr_t attr)
Initializes a reorder_primitive_desc using an attr attribute and descriptors of input and output memo...
blocked weights format
Definition: mkldnn_types.h:267
blocked weights format
Definition: mkldnn_types.h:316
Definition: mkldnn.hpp:2953
Convolution algorithm (either direct or Winograd) is chosen just in time.
Definition: mkldnn_types.h:449
softmax_backward(const primitive_desc &aprimitive_desc, const primitive::at &dst, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2428
blocked weights format
Definition: mkldnn_types.h:258
Definition: mkldnn.hpp:2994
Definition: mkldnn.hpp:258
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2292
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to dat...
blocked weights format
Definition: mkldnn_types.h:344
mkldnn_status_t MKLDNN_API mkldnn_stream_rerun(mkldnn_stream_t stream, mkldnn_primitive_t *error_primitive)
Reruns all the primitives within the stream.
2D weights tensor with physical layout io.
Definition: mkldnn_types.h:175
memory consumption: extra (scratch) memory, in addition to all input and output memory (bytes) ...
Definition: mkldnn_types.h:1156
A batch normalization primitive.
Definition: mkldnn_types.h:434
A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base class for primitive (mkldnn_p...
Definition: mkldnn.hpp:55
Definition: mkldnn_types.h:443
engine(kind akind, size_t index)
Constructs an engine.
Definition: mkldnn.hpp:529
Definition: mkldnn.hpp:2277
A descriptor of a pooling operation.
Definition: mkldnn_types.h:788
Definition: mkldnn.hpp:3265
Definition: mkldnn.hpp:272
Definition: mkldnn.hpp:273
engine get_engine()
Definition: mkldnn.hpp:787
error(mkldnn_status_t astatus, std::string amessage, mkldnn_primitive_t aerror_primitive=0)
Constructs an error instance.
Definition: mkldnn.hpp:173
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1956
const char * impl_info_str() const
Returns implementation name.
Definition: mkldnn.hpp:1256
deconvolution descriptor
Definition: mkldnn_types.h:1167
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1137
blocked weights format
Definition: mkldnn_types.h:318
shuffle_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:3236
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output)
Definition: mkldnn.hpp:944
primitive_desc(const desc &desc, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2230
mkldnn_memory_desc_t data
The underlying C API data structure.
Definition: mkldnn.hpp:727
mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch(const_mkldnn_primitive_desc_iterator_t iterator)
Fetches current primitive descriptor.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:1420
engine get_engine()
Definition: mkldnn.hpp:963
int MKLDNN_API mkldnn_primitive_desc_query_s32(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for signed 32-bit int.
8-bit signed integer.
Definition: mkldnn_types.h:72
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output)
Initializes a reorder_primitive_desc using descriptors of input and output memory primitives...
The data in padding regions is zero.
Definition: mkldnn_types.h:373
int MKLDNN_API mkldnn_rnn_cell_get_states_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of states of a particular rnn_cell_desc.
Definition: mkldnn.hpp:2291
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2889
source memory primitive desc
Definition: mkldnn_types.h:1181
mkldnn_primitive_kind_t
Kinds of primitives.
Definition: mkldnn_types.h:404
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1842
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1929
Definition: mkldnn.hpp:3196
Winograd deconvolution.
Definition: mkldnn_types.h:453
Definition: mkldnn.hpp:248
number of inputs expected
Definition: mkldnn_types.h:1152
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2368
Definition: mkldnn.hpp:345
Definition: mkldnn.hpp:3018
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2467
desc(prop_kind aprop_kind, algorithm alg_kind, const memory::desc &src_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2281
An unspecified stream.
Definition: mkldnn_types.h:1198
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1752
void * get_data_handle() const
Returns a handle of the data contained in the memory primitive. On the CPU engine, this is a pointer to the allocated memory.
Definition: mkldnn.hpp:847
A view primitive.
Definition: mkldnn_types.h:410
size_t MKLDNN_API mkldnn_memory_primitive_desc_get_size(const_mkldnn_primitive_desc_t memory_primitive_desc)
Returns the size (in bytes) that is required for given memory_primitive_desc.
Definition: mkldnn.hpp:3065
Definition: mkldnn.hpp:262
Definition: mkldnn.hpp:328
Definition: mkldnn.hpp:3099
blocked weights format
Definition: mkldnn_types.h:290
mkldnn_primitive_kind_t convert_to_c(primitive::kind akind)
Definition: mkldnn.hpp:154
Definition: mkldnn.hpp:340
Definition: mkldnn.hpp:331
Definition: mkldnn.hpp:323
Definition: mkldnn.hpp:333
Average pooling exclude padding.
Definition: mkldnn_types.h:479
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops(const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops)
Returns post_ops for given attr.
mkldnn_status_t MKLDNN_API mkldnn_primitive_create(mkldnn_primitive_t *primitive, const_mkldnn_primitive_desc_t primitive_desc, const mkldnn_primitive_at_t *inputs, const_mkldnn_primitive_t *outputs)
Creates a primitive using a primitive_desc descriptor and arrays of inputs and outputs.
primitive::kind kind(int index) const
Definition: mkldnn.hpp:377
Definition: mkldnn_types.h:914
Forward data propagation (inference mode).
Definition: mkldnn_types.h:387
primitive_attr get_primitive_attr() const
Definition: mkldnn.hpp:1242
6D grouped weights tensor with the physical layout goidhw, used in Caffe.
Definition: mkldnn_types.h:217
5D weights tensor with physical layout iodhw, used in Caffe.
Definition: mkldnn_types.h:196
A class that provides the destructor for an Intel(R) MKL-DNN C handle.
Definition: mkldnn.hpp:40
data_type
Data type specification. See mkldnn_data_type_t for a detailed description.
Definition: mkldnn.hpp:594
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const memory &dst)
Definition: mkldnn.hpp:2507
Direct deconvolution.
Definition: mkldnn_types.h:451
Eltwise: abs.
Definition: mkldnn_types.h:463
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2529
blocked weights format
Definition: mkldnn_types.h:328
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2255
blocked weights format
Definition: mkldnn_types.h:278
A memory descriptor.
Definition: mkldnn.hpp:724
deconvolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1851
5D grouped weights tensor with the physical layout hwigo, used in TensorFlow.
Definition: mkldnn_types.h:210
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2295
blocked weights format
Definition: mkldnn_types.h:336
bool operator!=(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:911
void set_rnn_weights_qparams(int mask, const std::vector< float > &scales)
Definition: mkldnn.hpp:480
handle(T t=0, bool weak=false)
Constructs a C handle wrapper.
Definition: mkldnn.hpp:67
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for forward propagation using prop_kind (possi...
Eltwise: hyperbolic tangent non-linearity (tanh)
Definition: mkldnn_types.h:457
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2878
mkldnn_status_t status
Definition: mkldnn.hpp:162
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1779
T get() const
Returns the value of the underlying C handle.
Definition: mkldnn.hpp:85
mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine)
Destroys an engine.
view(const primitive_desc &view_pd, primitive::at input)
Definition: mkldnn.hpp:1029
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1889
blocked weights format
Definition: mkldnn_types.h:317
2D data tensor.
Definition: mkldnn_types.h:148
primitive_desc(const desc &desc, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2661
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2768
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for backward propagation with respect to data ...
bool wait(bool block=true)
Waits for all computations submitted to the stream to complete.
Definition: mkldnn.hpp:3315
mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for backward propagation using alg_kind, memory descriptors data_desc...
Primitive or engine failed on execution.
Definition: mkldnn_types.h:56
memory descriptor for memory and view
Definition: mkldnn_types.h:1165
view(memory input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:1038
Definition: mkldnn.hpp:266
An LRN primitive.
Definition: mkldnn_types.h:432
Definition: mkldnn_types.h:940
mkldnn_padding_kind_t
Kinds of padding.
Definition: mkldnn_types.h:371
rnn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const primitive::at &dst_layer, const primitive::at &dst_iter, const memory &diff_src_layer, const memory &diff_src_iter, const memory &diff_weights_layer, const memory &diff_weights_iter, const memory &diff_bias, const primitive::at &diff_dst_layer, const primitive::at &diff_dst_iter, const primitive::at &workspace)
Definition: mkldnn.hpp:3127
Lazy stream.
Definition: mkldnn_types.h:1202
Definition: mkldnn.hpp:332
desc(const memory::desc &diff_desc, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2405
blocked weights format
Definition: mkldnn_types.h:340
Definition: mkldnn.hpp:304
void get_output_scales(int &mask, std::vector< float > &scales) const
Definition: mkldnn.hpp:439
blocked weights format
Definition: mkldnn_types.h:260
desc(algorithm kind)
Definition: mkldnn.hpp:2963
primitive_desc(const desc &desc, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3100
5D RNN weights tensor in the format (num_layers, num_directions, num_gates, output_channels, input_channels).
Definition: mkldnn_types.h:238
blocked weights format
Definition: mkldnn_types.h:310
const_mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_query_pd(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for primitive descriptor.
Definition: mkldnn.hpp:2876
shuffle descriptor
Definition: mkldnn_types.h:1168
Forward data propagation (training mode).
Definition: mkldnn_types.h:383
Definition: mkldnn.hpp:344
primitive_desc(const desc &desc, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2092
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:2915
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1532
memory(const primitive &aprimitive)
Constructs a memory primitive from a generic primitive.
Definition: mkldnn.hpp:793
3D data tensor with the physical layout nwc.
Definition: mkldnn_types.h:154
engine get_engine()
Definition: mkldnn.hpp:1108
post_ops()
Definition: mkldnn.hpp:368
An opaque structure to describe a primitive.
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const primitive::at &workspace, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2701
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition: mkldnn_types.h:144
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1332
mkldnn_data_type_t
Data type specification.
Definition: mkldnn_types.h:62
Definition: mkldnn.hpp:1457
Definition: mkldnn.hpp:325
Definition: mkldnn.hpp:318
convolution descriptor
Definition: mkldnn_types.h:1166
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1506
A memory primitive descriptor.
Definition: mkldnn.hpp:751
Definition: mkldnn.hpp:314
Definition: mkldnn.hpp:2413
mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for forward propagation using prop_kind (possible values are mkldnn_forward_t...
blocked weights format
Definition: mkldnn_types.h:301
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1502
blocked weights format
Definition: mkldnn_types.h:294
handle & operator=(const handle &other)
Definition: mkldnn.hpp:72
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2631
Eltwise: bounded_relu.
Definition: mkldnn_types.h:469
Definition: mkldnn.hpp:2367
#define REG_QUERY_MPD(name, what, idx)
Definition: mkldnn.hpp:1306
Definition: mkldnn_types.h:937
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1442
mkldnn_engine_kind_t
Kinds of engines.
Definition: mkldnn_types.h:995
Definition: mkldnn_types.h:910
int MKLDNN_API mkldnn_rnn_cell_get_gates_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of gates of a particular rnn_cell_desc.
Queried element is not required for given primitive.
Definition: mkldnn_types.h:58
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3019
blocked weights format
Definition: mkldnn_types.h:357
bool operator!=(const T other) const
Definition: mkldnn.hpp:62
Memory primitive that describes the data.
Definition: mkldnn.hpp:579
Weights format used in 8bit Winograd convolution.
Definition: mkldnn_types.h:361
Definition: mkldnn.hpp:327
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2028
Definition: mkldnn.hpp:2066
Definition: mkldnn.hpp:301
Round nearest.
Definition: mkldnn_types.h:80
blocked weights format
Definition: mkldnn_types.h:356
Definition: mkldnn.hpp:243
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2682
Definition: mkldnn.hpp:1668
const void * const_mkldnn_op_desc_t
A pointer to any of the operation descriptors (constant variant).
Definition: mkldnn_types.h:628
static mkldnn_stream_kind_t convert_to_c(kind akind)
Definition: mkldnn.hpp:3272
blocked weights format
Definition: mkldnn_types.h:259
blocked weights format
Definition: mkldnn_types.h:353
Definition: mkldnn.hpp:1866
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1096
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create_v2(mkldnn_primitive_desc_iterator_t *iterator, const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, mkldnn_engine_t engine, const_mkldnn_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive descriptor iterator for given op_desc, attr, engine, and optionally a hint primit...
Definition: mkldnn.hpp:2449
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &workspace)
Definition: mkldnn.hpp:2191
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1429
4D weights tensor with physical layout iohw.
Definition: mkldnn_types.h:193
A reorder primitive.
Definition: mkldnn_types.h:412
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:1755
rnn_direction
Definition: mkldnn.hpp:299
primitive_desc(const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1166
blocked weights format
Definition: mkldnn_types.h:337
blocked weights format
Definition: mkldnn_types.h:297
An unspecified engine.
Definition: mkldnn_types.h:997
desc(const mkldnn_memory_desc_t &adata)
Constructs a memory descriptor from a C API data structure.
Definition: mkldnn.hpp:747
blocked weights format
Definition: mkldnn_types.h:313
Definition: mkldnn.hpp:1136
int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops)
Returns the length of post operations for given post_ops.
engine get_engine()
Definition: mkldnn.hpp:1026
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2234
blocked weights format
Definition: mkldnn_types.h:338
blocked weights format
Definition: mkldnn_types.h:327
mkldnn_alg_kind_t
Kinds of algorithms.
Definition: mkldnn_types.h:442
primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2901
Definition: mkldnn.hpp:263
inner product descriptor
Definition: mkldnn_types.h:1174
A pooling primitive.
Definition: mkldnn_types.h:430
weights memory primitive descriptor desc
Definition: mkldnn_types.h:1183
output memory primitive desc
Definition: mkldnn_types.h:1180
Definition: mkldnn.hpp:2229
5D weights tensor with physical layout dhwio, used in TensorFlow.
Definition: mkldnn_types.h:199
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2031
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2451
Definition: mkldnn.hpp:943
mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy(mkldnn_primitive_t primitive)
Deletes a primitive.
Definition: mkldnn.hpp:334
std::string message
Definition: mkldnn.hpp:163
Definition: mkldnn.hpp:3184
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to weights usi...
mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc, const mkldnn_memory_desc_t *diff_src_layer_desc, const mkldnn_memory_desc_t *diff_src_iter_desc, const mkldnn_memory_desc_t *diff_weights_layer_desc, const mkldnn_memory_desc_t *diff_weights_iter_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_layer, const mkldnn_memory_desc_t *diff_dst_iter_desc)
Initializes a rnn descriptor rnn_desc for backward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
primitive_desc(const desc &desc, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2331
Definition: mkldnn.hpp:315
blocked weights format
Definition: mkldnn_types.h:287
handle(const handle &other)
Definition: mkldnn.hpp:71
Forward data propagation (alias for mkldnn_forward_training)
Definition: mkldnn_types.h:391
3D RNN data tensor in the format (batch, seq_length, input channels).
Definition: mkldnn_types.h:219
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales(mkldnn_primitive_attr_t attr, int count, int mask, const float *scales)
Sets output scales for primitive operations.
Definition: mkldnn.hpp:241
lrn descriptor
Definition: mkldnn_types.h:1172
workspace memory primitive desc
Definition: mkldnn_types.h:1187
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2119
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1593
bool next_impl()
Advances the next implementation for the given op descriptor.
Definition: mkldnn.hpp:1270
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to weights usin...
blocked weights format
Definition: mkldnn_types.h:264
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1670
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2649
blocked weights format
Definition: mkldnn_types.h:302
Definition: mkldnn.hpp:224
weights format with additional buffer size equal to the number of output channels and containing the ...
Definition: mkldnn_types.h:274
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2096
float get_clipping() const
Definition: mkldnn.hpp:2978
weights grad.
Definition: mkldnn_types.h:1184
4D data tensor with the physical layout nchw, used in Caffe.
Definition: mkldnn_types.h:157
Definition: mkldnn.hpp:321
mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc)
Initializes a rnn descriptor rnn_desc for forward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
void append_eltwise(float scale, algorithm alg, float alpha, float beta)
Definition: mkldnn.hpp:395
primitive kind
Definition: mkldnn_types.h:1150
blocked data format
Definition: mkldnn_types.h:250
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1815
int get_state_count() const
Definition: mkldnn.hpp:2987
blocked weights format
Definition: mkldnn_types.h:285
Definition: mkldnn.hpp:317
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2207
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2544
kind
Definition: mkldnn.hpp:3268
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1460
Definition: mkldnn.hpp:339
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc)
Definition: mkldnn.hpp:2996
mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init(mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc)
Initializes an inner product descriptor ip_desc for forward propagation using prop_kind (possible val...