Open3D (C++ API)  0.18.0
Loading...
Searching...
No Matches
ShapeChecking.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// Copyright (c) 2018-2023 www.open3d.org
5// SPDX-License-Identifier: MIT
6// ----------------------------------------------------------------------------
7
8#pragma once
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
13
14namespace open3d {
15namespace ml {
16namespace op_util {
17
/// Class for representing a possibly unknown dimension value.
///
/// A default-constructed DimValue is unknown; a DimValue constructed from an
/// integer is a known constant.
class DimValue {
public:
    /// Creates an unknown dimension value.
    DimValue() : value_(0), constant_(false) {}

    /// Creates a known dimension value. Intentionally implicit so that plain
    /// integers can be used wherever a DimValue is expected.
    DimValue(int64_t v) : value_(v), constant_(true) {}

    /// Multiplies this value by \p b. The product stays known only if both
    /// factors are known; otherwise this value becomes unknown.
    DimValue& operator*=(const DimValue& b) {
        if (constant_ && b.constant_)
            value_ *= b.value_;
        else
            constant_ = false;
        return *this;
    }

    /// Returns the decimal value, or "?" if the value is unknown.
    std::string ToString() const {
        if (constant_)
            return std::to_string(value_);
        else
            return "?";
    }

    /// Returns a mutable reference to the stored value.
    /// \throws std::runtime_error if the value is unknown.
    int64_t& value() {
        if (!constant_) throw std::runtime_error("DimValue is not constant");
        return value_;
    }

    /// Read-only variant of value() for const objects; same contract.
    /// \throws std::runtime_error if the value is unknown.
    int64_t value() const {
        if (!constant_) throw std::runtime_error("DimValue is not constant");
        return value_;
    }

    /// Mutable flag telling whether the value is known.
    bool& constant() { return constant_; }

    /// Read-only variant of constant() for const objects.
    bool constant() const { return constant_; }

private:
    int64_t value_;
    bool constant_;
};
46
47inline DimValue UnknownValue() { return DimValue(); }
48
/// Class for dimensions for which the value should be inferred.
///
/// A Dim is either a fixed constant (constructed from an integer) or a
/// placeholder with an initially unknown value. A placeholder points to
/// itself through origin_; copies keep that pointer, so assigning a value
/// through a copy updates the original placeholder.
class Dim {
public:
    /// Unnamed placeholder with an unknown value.
    explicit Dim() : value_(0), constant_(false), origin_(this) {}

    /// Named placeholder with an unknown value.
    explicit Dim(const std::string& name)
        : value_(0), constant_(false), origin_(this), name_(name) {}

    /// Fixed dimension with a known value; copies do not share state.
    Dim(int64_t value, const std::string& name = "")
        : value_(value), constant_(true), origin_(nullptr), name_(name) {}

    /// Copies alias the origin of \p other (if any).
    Dim(const Dim& other)
        : value_(other.value_),
          constant_(other.constant_),
          origin_(other.origin_),
          name_(other.name_) {}

    ~Dim() {}

    Dim& operator=(const Dim&) = delete;

    /// Value storage of the origin placeholder, or of this object if fixed.
    int64_t& value() { return origin_ ? origin_->value_ : value_; }

    /// Known-flag storage of the origin placeholder, or of this object.
    bool& constant() { return origin_ ? origin_->constant_ : constant_; }

    /// Assigns \p a if the value is still unknown. Returns whether the
    /// (possibly just assigned) value equals \p a.
    bool assign(int64_t a) {
        if (constant()) return value() == a;
        value() = a;
        constant() = true;
        return true;
    }

    /// Returns "name(value)", just "name" when \p show_value is false, or
    /// only the value for unnamed dims. Unknown values are printed as "?".
    std::string ToString(bool show_value = true) {
        const std::string value_text =
                constant() ? std::to_string(value()) : std::string("?");
        if (name_.empty()) return value_text;
        return show_value ? name_ + "(" + value_text + ")" : name_;
    }

private:
    int64_t value_;
    bool constant_;
    Dim* origin_;
    std::string name_;
};
114
115//
116// Dim expression operator classes
117//
118
/// Addition operator for dim expressions.
struct DimXPlus {
    static bool constant() { return true; }
    static int64_t apply(int64_t a, int64_t b) { return a + b; }

    /// Solves `a + b == ans` for the single unknown operand.
    /// \throws std::runtime_error if both operands are unknown.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        if (a.constant()) {
            return b.assign(ans - a.value());
        }
        if (b.constant()) {
            return a.assign(ans - b.value());
        }
        std::string exstr =
                GetString(a, false) + ToString() + GetString(b, false);
        throw std::runtime_error("Illegal dim expression: " + exstr);
    }

    static std::string ToString() { return "+"; }
};
139
/// Subtraction operator for dim expressions.
struct DimXMinus {
    static bool constant() { return true; }
    static int64_t apply(int64_t a, int64_t b) { return a - b; }

    /// Solves `a - b == ans` for the single unknown operand.
    /// \throws std::runtime_error if both operands are unknown.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        if (a.constant()) {
            // a - b == ans  =>  b == a - ans
            return b.assign(a.value() - ans);
        }
        if (b.constant()) {
            // a - b == ans  =>  a == ans + b
            return a.assign(ans + b.value());
        }
        std::string exstr =
                GetString(a, false) + ToString() + GetString(b, false);
        throw std::runtime_error("Illegal dim expression: " + exstr);
    }

    static std::string ToString() { return "-"; }
};
160
/// Multiplication operator for dim expressions.
///
/// Unknown operands are never inferred from a product: backprop always
/// throws.
struct DimXMultiply {
    static bool constant() { return true; }
    static int64_t apply(int64_t a, int64_t b) { return a * b; }

    /// \throws std::runtime_error always; products cannot be back-propagated.
    template <class T1, class T2>
    static bool backprop(int64_t /*ans*/, T1 a, T2 b) {
        std::string exstr =
                GetString(a, false) + ToString() + GetString(b, false);
        throw std::runtime_error("Illegal dim expression: " + exstr);
    }

    static std::string ToString() { return "*"; }
};
175
/// Integer division operator for dim expressions.
///
/// Unknown operands are never inferred from a quotient: backprop always
/// throws.
struct DimXDivide {
    static bool constant() { return true; }

    /// Integer division a / b.
    /// \throws std::runtime_error if \p b is zero (would otherwise be UB).
    static int64_t apply(int64_t a, int64_t b) {
        if (b == 0) {
            throw std::runtime_error("Division by zero in dim expression");
        }
        return a / b;
    }

    /// \throws std::runtime_error always; quotients cannot be
    /// back-propagated.
    template <class T1, class T2>
    static bool backprop(int64_t /*ans*/, T1 a, T2 b) {
        std::string exstr =
                GetString(a, false) + ToString() + GetString(b, false);
        throw std::runtime_error("Illegal dim expression: " + exstr);
    }

    static std::string ToString() { return "/"; }
};
190
/// "Or" operator for dim expressions: a value matches if it can be assigned
/// to at least one side. The expression itself has no computable value.
struct DimXOr {
    static bool constant() { return false; }

    /// \throws std::runtime_error always.
    static int64_t apply(int64_t /*a*/, int64_t /*b*/) {
        throw std::runtime_error("Cannot evaluate OR expression");
    }

    /// Tries the left side first; the right side is only tried if the left
    /// assignment fails.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        if (a.assign(ans)) return true;
        return b.assign(ans);
    }

    static std::string ToString() { return "||"; }
};
204
/// Dim expression class.
///
/// Stores copies of both operands and combines them with TOp. The
/// expression is constant only if both operands are constant at construction
/// time and TOp::constant() is true.
template <class TLeft, class TRight, class TOp>
class DimX {
public:
    static DimX<TLeft, TRight, TOp> Create(TLeft left, TRight right) {
        return DimX(left, right);
    }

    /// Evaluates the expression; returns 0 if it is not constant.
    int64_t value() {
        return constant_ ? TOp::apply(left_.value(), right_.value()) : 0;
    }

    bool& constant() { return constant_; }

    /// Assigns a value to the expression. For a constant expression this
    /// only compares; otherwise TOp::backprop infers the unknown operand.
    bool assign(int64_t a) {
        if (!constant_) return TOp::backprop(a, left_, right_);
        return value() == a;
    }

    /// Returns "<left><op><right>".
    std::string ToString(bool show_value = true) {
        std::string text = left_.ToString(show_value);
        text += TOp::ToString();
        text += right_.ToString(show_value);
        return text;
    }

private:
    DimX(TLeft left, TRight right)
        : left_(left),
          right_(right),
          constant_(left.constant() && right.constant() && TOp::constant()) {}

    TLeft left_;
    TRight right_;
    bool constant_;
};
244
245//
246// define operators for dim expressions
247//
248
249#define DEFINE_DIMX_OPERATOR(opclass, symbol) \
250 inline DimX<Dim, Dim, opclass> operator symbol(Dim a, Dim b) { \
251 return DimX<Dim, Dim, opclass>::Create(a, b); \
252 } \
253 \
254 template <class TL, class TR, class TOp> \
255 inline DimX<Dim, DimX<TL, TR, TOp>, opclass> operator symbol( \
256 Dim a, DimX<TL, TR, TOp>&& b) { \
257 return DimX<Dim, DimX<TL, TR, TOp>, opclass>::Create(a, b); \
258 } \
259 \
260 template <class TL, class TR, class TOp> \
261 inline DimX<DimX<TL, TR, TOp>, Dim, opclass> operator symbol( \
262 DimX<TL, TR, TOp>&& a, Dim b) { \
263 return DimX<DimX<TL, TR, TOp>, Dim, opclass>::Create(a, b); \
264 } \
265 \
266 template <class TL1, class TR1, class TOp1, class TL2, class TR2, \
267 class TOp2> \
268 inline DimX<DimX<TL1, TR1, TOp1>, DimX<TL2, TR2, TOp2>, opclass> \
269 operator symbol(DimX<TL1, TR1, TOp1>&& a, DimX<TL2, TR2, TOp2>&& b) { \
270 return DimX<DimX<TL1, TR1, TOp1>, DimX<TL2, TR2, TOp2>, \
271 opclass>::Create(a, b); \
272 }
273
274DEFINE_DIMX_OPERATOR(DimXPlus, +)
275DEFINE_DIMX_OPERATOR(DimXMinus, -)
276DEFINE_DIMX_OPERATOR(DimXMultiply, *)
277DEFINE_DIMX_OPERATOR(DimXDivide, /)
278DEFINE_DIMX_OPERATOR(DimXOr, ||)
279#undef DEFINE_DIMX_OPERATOR
280
281//
282// define operators for comparing DimValue to dim expressions.
283// Using these operators will try to assign the dim value to the expression.
284//
285
286template <class TLeft, class TRight, class TOp>
288 if (a.constant()) {
289 auto b_copy(b);
290 return b_copy.assign(a.value());
291 } else
292 return true;
293}
294
295inline bool operator==(DimValue a, Dim b) {
296 if (a.constant())
297 return b.assign(a.value());
298 else
299 return true;
300}
301
302//
303// some helper classes
304//
305
/// Compile-time number of types in the parameter pack.
template <class... args>
struct CountArgs {
    static constexpr size_t value = sizeof...(args);
};
310
311template <class TLeft, class TRight, class TOp>
312std::string GetString(DimX<TLeft, TRight, TOp> a, bool show_value = true) {
313 return a.ToString(show_value);
314}
315
316inline std::string GetString(Dim a, bool show_value = true) {
317 return a.ToString(show_value);
318}
319
320template <class TLeft, class TRight, class TOp>
322 return a.value();
323}
324
325template <class TLeft, class TRight, class TOp>
326int64_t GetValue(DimX<TLeft, TRight, TOp> a, int64_t unknown_dim_value) {
327 if (a.constant()) {
328 return a.value();
329 } else {
330 return unknown_dim_value;
331 }
332 return a.value();
333}
334
335inline int64_t GetValue(Dim a) { return a.value(); }
336
337inline int64_t GetValue(Dim a, int64_t unknown_dim_value) {
338 if (a.constant()) {
339 return a.value();
340 } else {
341 return unknown_dim_value;
342 }
343}
344
/// Joins the string forms of the given dim expressions with ", ".
/// The zero-argument overload yields an empty string.
inline std::string CreateDimXString() { return std::string(); }

template <class TDimX>
std::string CreateDimXString(TDimX dimex) {
    std::string joined = GetString(dimex);
    return joined;
}

template <class TDimX, class... TArgs>
std::string CreateDimXString(TDimX dimex, TArgs... args) {
    std::string joined = GetString(dimex);
    joined += ", ";
    joined += CreateDimXString(args...);
    return joined;
}
356
/// Appends the value of \p dimex to \p out, using \p unknown_dim_value if the
/// expression has no known value (see GetValue).
template <class TDimX>
void CreateDimVector(std::vector<int64_t>& out,
                     int64_t unknown_dim_value,
                     TDimX dimex) {
    const int64_t dim = GetValue(dimex, unknown_dim_value);
    out.push_back(dim);
}

/// Appends the values of all given dim expressions to \p out, in order.
template <class TDimX, class... TArgs>
void CreateDimVector(std::vector<int64_t>& out,
                     int64_t unknown_dim_value,
                     TDimX dimex,
                     TArgs... args) {
    CreateDimVector(out, unknown_dim_value, dimex);
    CreateDimVector(out, unknown_dim_value, args...);
}

/// Returns a vector with the value of \p dimex.
template <class TDimX>
std::vector<int64_t> CreateDimVector(int64_t unknown_dim_value, TDimX dimex) {
    std::vector<int64_t> dims;
    CreateDimVector(dims, unknown_dim_value, dimex);
    return dims;
}

/// Returns a vector with the values of all given dim expressions, in order.
template <class TDimX, class... TArgs>
std::vector<int64_t> CreateDimVector(int64_t unknown_dim_value,
                                     TDimX dimex,
                                     TArgs... args) {
    std::vector<int64_t> dims;
    CreateDimVector(dims, unknown_dim_value, dimex, args...);
    return dims;
}
388
389//
390// classes which check if the dim value is compatible with the expression
391//
392
393template <class TLeft, class TRight, class TOp>
395 bool status = (lhs == std::forward<DimX<TLeft, TRight, TOp>>(rhs));
396 return status;
397}
398
399inline bool CheckDim(const DimValue& lhs, Dim d) {
400 bool status = lhs == d;
401 return status;
402}
403
412
413template <CSOpt Opt = CSOpt::NONE, class TDimX>
414bool _CheckShape(const std::vector<DimValue>& shape, TDimX&& dimex) {
415 // check rank
416 const int rank_diff = shape.size() - 1;
417 if (Opt != CSOpt::NONE) {
418 if (rank_diff < 0) {
419 return false;
420 }
421 } else {
422 if (rank_diff != 0) {
423 return false;
424 }
425 }
426
427 // check dim
428 bool status;
429 if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
430 DimValue s(1);
431 for (int i = 0; i < rank_diff + 1; ++i) s *= shape[i];
432 status = CheckDim(s, std::forward<TDimX>(dimex));
433 } else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
434 status = CheckDim(shape[rank_diff], std::forward<TDimX>(dimex));
435 } else if (Opt == CSOpt::COMBINE_LAST_DIMS) {
436 DimValue s(1);
437 for (DimValue x : shape) s *= x;
438 status = CheckDim(s, std::forward<TDimX>(dimex));
439 } else {
440 status = CheckDim(shape[0], std::forward<TDimX>(dimex));
441 }
442 return status;
443}
444
445template <CSOpt Opt = CSOpt::NONE, class TDimX, class... TArgs>
446bool _CheckShape(const std::vector<DimValue>& shape,
447 TDimX&& dimex,
448 TArgs&&... args) {
449 // check rank
450 const int rank_diff = shape.size() - (CountArgs<TArgs...>::value + 1);
451 if (Opt != CSOpt::NONE) {
452 if (rank_diff < 0) {
453 return false;
454 }
455 } else {
456 if (rank_diff != 0) {
457 return false;
458 }
459 }
460
461 // check dim
462 bool status;
463 if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
464 DimValue s(1);
465 for (int i = 0; i < rank_diff + 1; ++i) s *= shape[i];
466 status = CheckDim(s, std::forward<TDimX>(dimex));
467 } else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
468 status = CheckDim(shape[rank_diff], std::forward<TDimX>(dimex));
469 } else {
470 status = CheckDim(shape[0], std::forward<TDimX>(dimex));
471 }
472
473 const int offset = 1 + (Opt == CSOpt::COMBINE_FIRST_DIMS ||
475 ? rank_diff
476 : 0);
477 std::vector<DimValue> shape2(shape.begin() + offset, shape.end());
478 bool status2 = _CheckShape<Opt>(shape2, std::forward<TArgs>(args)...);
479
480 return status && status2;
481}
482
573template <CSOpt Opt = CSOpt::NONE, class TDimX, class... TArgs>
574std::tuple<bool, std::string> CheckShape(const std::vector<DimValue>& shape,
575 TDimX&& dimex,
576 TArgs&&... args) {
577 const bool status = _CheckShape<Opt>(shape, std::forward<TDimX>(dimex),
578 std::forward<TArgs>(args)...);
579 if (status) {
580 return std::make_tuple(status, std::string());
581 } else {
582 const int rank_diff = shape.size() - (CountArgs<TArgs...>::value + 1);
583
584 // generate string for the actual shape. This is a bit involved because
585 // of the many options.
586 std::string shape_str;
587 if (rank_diff <= 0) {
588 shape_str = "[";
589 for (int i = 0; i < int(shape.size()); ++i) {
590 shape_str += shape[i].ToString();
591 if (i + 1 < int(shape.size())) shape_str += ", ";
592 }
593 shape_str += "]";
594 } else {
595 if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
596 shape_str += "[";
597 for (int i = 0; i < rank_diff; ++i) {
598 shape_str += shape[i].ToString();
599 if (i + 1 < int(shape.size())) shape_str += "*";
600 }
601 } else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
602 shape_str += "(";
603 for (int i = 0; i < rank_diff; ++i) {
604 shape_str += shape[i].ToString();
605 if (i + 1 < rank_diff) shape_str += ", ";
606 }
607 shape_str += ")[";
608 } else {
609 shape_str = "[";
610 }
611 int start = 0;
612 if (Opt == CSOpt::COMBINE_FIRST_DIMS ||
614 start = rank_diff;
615 }
616
617 int end = shape.size();
618 if (Opt == CSOpt::COMBINE_LAST_DIMS) {
619 end -= rank_diff + 1;
620 } else if (Opt == CSOpt::IGNORE_LAST_DIMS) {
621 end -= rank_diff;
622 }
623 for (int i = start; i < end; ++i) {
624 shape_str += shape[i].ToString();
625 if (i + 1 < end) shape_str += ", ";
626 }
627 if (Opt == CSOpt::COMBINE_LAST_DIMS) {
628 shape_str += ", ";
629 for (int i = std::max<int>(0, shape.size() - rank_diff - 1);
630 i < int(shape.size()); ++i) {
631 shape_str += shape[i].ToString();
632 if (i + 1 < int(shape.size())) shape_str += "*";
633 }
634 shape_str += "]";
635 } else if (Opt == CSOpt::IGNORE_LAST_DIMS) {
636 shape_str += "](";
637 for (int i = std::max<int>(0, shape.size() - rank_diff);
638 i < int(shape.size()); ++i) {
639 shape_str += shape[i].ToString();
640 if (i + 1 < int(shape.size())) shape_str += ", ";
641 }
642 shape_str += ")";
643 } else {
644 shape_str += "]";
645 }
646 }
647
648 // generate string for the expected shape with the dim expressions
649 std::string expected_shape;
650 if ((CountArgs<TArgs...>::value + 1) == 1) {
651 expected_shape = "[" + GetString(dimex) + "]";
652
653 } else {
654 expected_shape = "[" + GetString(dimex) + ", " +
655 CreateDimXString(args...) + "]";
656 }
657
658 std::string errstr;
659 // print rank information if there is a problem with the rank
660 if ((Opt != CSOpt::NONE && rank_diff < 0) ||
661 (Opt == CSOpt::NONE && rank_diff != 0)) {
662 errstr = "got rank " + std::to_string(shape.size()) + " " +
663 shape_str + ", expected rank " +
664 std::to_string(CountArgs<TArgs...>::value + 1) + " " +
665 expected_shape;
666 } else { // rank is OK print just the shapes
667 errstr = "got " + shape_str + ", expected " + expected_shape;
668 }
669 return std::make_tuple(status, errstr);
670 }
671}
672
673} // namespace op_util
674} // namespace ml
675} // namespace open3d
#define DEFINE_DIMX_OPERATOR(opclass, symbol)
Definition ShapeChecking.h:249
Class for dimensions for which the value should be inferred.
Definition ShapeChecking.h:50
int64_t & value()
Definition ShapeChecking.h:70
bool assign(int64_t a)
Definition ShapeChecking.h:86
Dim & operator=(const Dim &)=delete
Dim(const Dim &other)
Definition ShapeChecking.h:60
Dim()
Definition ShapeChecking.h:52
Dim(const std::string &name)
Definition ShapeChecking.h:54
~Dim()
Definition ShapeChecking.h:66
bool & constant()
Definition ShapeChecking.h:77
std::string ToString(bool show_value=true)
Definition ShapeChecking.h:94
Dim(int64_t value, const std::string &name="")
Definition ShapeChecking.h:57
Class for representing a possibly unknown dimension value.
Definition ShapeChecking.h:19
DimValue & operator*=(const DimValue &b)
Definition ShapeChecking.h:23
DimValue(int64_t v)
Definition ShapeChecking.h:22
bool & constant()
Definition ShapeChecking.h:40
DimValue()
Definition ShapeChecking.h:21
int64_t & value()
Definition ShapeChecking.h:36
std::string ToString() const
Definition ShapeChecking.h:30
Dim expression class.
Definition ShapeChecking.h:207
std::string ToString(bool show_value=true)
Definition ShapeChecking.h:231
bool assign(int64_t a)
assigns a value to the expression
Definition ShapeChecking.h:223
static DimX< TLeft, TRight, TOp > Create(TLeft left, TRight right)
Definition ShapeChecking.h:209
int64_t value()
Definition ShapeChecking.h:213
bool & constant()
Definition ShapeChecking.h:220
std::string name
Definition FilePCD.cpp:39
int offset
Definition FilePCD.cpp:45
bool operator==(DimValue a, DimX< TLeft, TRight, TOp > &&b)
Definition ShapeChecking.h:287
std::string GetString(DimX< TLeft, TRight, TOp > a, bool show_value=true)
Definition ShapeChecking.h:312
DimValue UnknownValue()
Definition ShapeChecking.h:47
CSOpt
Check shape options.
Definition ShapeChecking.h:405
std::tuple< bool, std::string > CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex, TArgs &&... args)
Definition ShapeChecking.h:574
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition ShapeChecking.h:358
bool CheckDim(const DimValue &lhs, DimX< TLeft, TRight, TOp > &&rhs)
Definition ShapeChecking.h:394
bool _CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex)
Definition ShapeChecking.h:414
int64_t GetValue(DimX< TLeft, TRight, TOp > a)
Definition ShapeChecking.h:321
std::string CreateDimXString()
Definition ShapeChecking.h:345
Definition PinholeCameraIntrinsic.cpp:16
Definition ShapeChecking.h:307
static const size_t value
Definition ShapeChecking.h:308
Definition ShapeChecking.h:176
static bool backprop(int64_t ans, T1 a, T2 b)
Definition ShapeChecking.h:181
static bool constant()
Definition ShapeChecking.h:177
static int64_t apply(int64_t a, int64_t b)
Definition ShapeChecking.h:178
static std::string ToString()
Definition ShapeChecking.h:188
Definition ShapeChecking.h:140
static std::string ToString()
Definition ShapeChecking.h:158
static bool backprop(int64_t ans, T1 a, T2 b)
Definition ShapeChecking.h:145
static int64_t apply(int64_t a, int64_t b)
Definition ShapeChecking.h:142
static bool constant()
Definition ShapeChecking.h:141
Definition ShapeChecking.h:161
static bool constant()
Definition ShapeChecking.h:162
static std::string ToString()
Definition ShapeChecking.h:173
static int64_t apply(int64_t a, int64_t b)
Definition ShapeChecking.h:163
static bool backprop(int64_t ans, T1 a, T2 b)
Definition ShapeChecking.h:166
Definition ShapeChecking.h:191
static bool constant()
Definition ShapeChecking.h:192
static int64_t apply(int64_t a, int64_t b)
Definition ShapeChecking.h:193
static bool backprop(int64_t ans, T1 a, T2 b)
Definition ShapeChecking.h:198
static std::string ToString()
Definition ShapeChecking.h:202
Definition ShapeChecking.h:119
static bool backprop(int64_t ans, T1 a, T2 b)
Definition ShapeChecking.h:124
static bool constant()
Definition ShapeChecking.h:120
static std::string ToString()
Definition ShapeChecking.h:137
static int64_t apply(int64_t a, int64_t b)
Definition ShapeChecking.h:121