@@ -94,8 +94,18 @@ namespace {
94
94
// expand a single element splat to a multi-GB large tensor.
95
95
// The limit is arbitrary set low to allow expanding small computations, like
96
96
// shape manipulations for example.
97
+
// TODO(b/210478841): Define a constant folding policy that generalizes this.
97
98
constexpr int64_t kFoldExpandSplatEltLimit = 16;
98
99
100
+
// Similarly to the constant above, this is an arbitrary limit into how many
101
+
// elements can be folded by a binary operation folder.
102
+
// This limit doesn't apply to the following special cases:
103
+
// 1) Adding a zero.
104
+
// 2) Multiplying by one.
105
+
// 3) When both operands are splats.
106
+
// TODO(b/210478841): Define a constant folding policy that generalizes this.
107
+
constexpr int64_t kFoldBinaryOpEltLimit = 65536;
108
+
99
109
// Clamps value to the range [lower, upper]. Requires lower <= upper.
100
110
template <typename T>
101
111
static T Clamp(const T& value, const T& lower, const T& upper) {
@@ -5234,6 +5244,22 @@ static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
5234
5244
return {};
5235
5245
}
5236
5246
5247
+
// Special case for folding splats no matter how large.
5248
+
// Only covers the case of both attrs being splats; operation-specific cases
5249
+
// like adding a zero or multiplying by one are handled elsewhere.
5250
+
SplatElementsAttr splat_lhs = lhs.dyn_cast<SplatElementsAttr>();
5251
+
SplatElementsAttr splat_rhs = rhs.dyn_cast<SplatElementsAttr>();
5252
+
if (splat_lhs && splat_rhs) {
5253
+
return SplatElementsAttr::get(
5254
+
type, Convert()(splat_lhs.getSplatValue<ValType>(),
5255
+
splat_rhs.getSplatValue<ValType>()));
5256
+
}
5257
+
5258
+
// Prevent folding if lhs/rhs are too large.
5259
+
if (lhs.getNumElements() > kFoldBinaryOpEltLimit) {
5260
+
return {};
5261
+
}
5262
+
5237
5263
SmallVector<ValType, 6> values;
5238
5264
values.reserve(lhs.getNumElements());
5239
5265
for (const auto zip :
@@ -5315,40 +5341,60 @@ BINARY_FOLDER(RemOp, remainder);
5315
5341
BINARY_FOLDER(MaxOp, max);
5316
5342
BINARY_FOLDER(MinOp, min);
5317
5343
5318
-
OpFoldResult AddOp::fold(ArrayRef<Attribute> attrs) {
5319
-
if (attrs[0] && attrs[1]) {
5320
-
BINARY_FOLDER_INTERNAL(AddOp, std::plus)
5344
+
bool isSplatZero(SplatElementsAttr attr) {
5345
+
if (!attr) return false;
5346
+
if (attr.getElementType().isa<FloatType>()) {
5347
+
return attr.getSplatValue<APFloat>().isZero();
5348
+
} else if (attr.getElementType().isa<IntegerType>()) {
5349
+
return attr.getSplatValue<APInt>().isZero();
5350
+
} else {
5351
+
return false;
5321
5352
}
5353
+
}
5354
+
5355
+
OpFoldResult AddOp::fold(ArrayRef<Attribute> attrs) {
5322
5356
// Handle special case where one operand is 0: x + 0 => x
5323
5357
if (attrs[0] || attrs[1]) {
5324
-
SplatElementsAttr attr = attrs[0] ? attrs[0].dyn_cast<SplatElementsAttr>()
5325
-
: attrs[1].dyn_cast<SplatElementsAttr>();
5326
-
if (!attr) return {};
5327
-
Value result = attrs[0] ? rhs() : lhs();
5328
-
if (attr.getElementType().isa<FloatType>()) {
5329
-
if (attr.getSplatValue<APFloat>().isZero()) return result;
5330
-
} else if (attr.getElementType().isa<IntegerType>()) {
5331
-
if (attr.getSplatValue<APInt>().isZero()) return result;
5332
-
}
5358
+
SplatElementsAttr splat_lhs =
5359
+
attrs[0].dyn_cast_or_null<SplatElementsAttr>();
5360
+
SplatElementsAttr splat_rhs =
5361
+
attrs[1].dyn_cast_or_null<SplatElementsAttr>();
5362
+
if (isSplatZero(splat_lhs))
5363
+
return splat_rhs ? (OpFoldResult)splat_rhs : rhs();
5364
+
if (isSplatZero(splat_rhs))
5365
+
return splat_lhs ? (OpFoldResult)splat_lhs : lhs();
5366
+
}
5367
+
if (attrs[0] && attrs[1]) {
5368
+
BINARY_FOLDER_INTERNAL(AddOp, std::plus)
5333
5369
}
5334
5370
return {};
5335
5371
}
5336
5372
5337
-
OpFoldResult MulOp::fold(ArrayRef<Attribute> attrs) {
5338
-
if (attrs[0] && attrs[1]) {
5339
-
BINARY_FOLDER_INTERNAL(MulOp, std::multiplies);
5373
+
bool isSplatOne(SplatElementsAttr attr) {
5374
+
if (!attr) return false;
5375
+
if (attr.getElementType().isa<FloatType>()) {
5376
+
return attr.getSplatValue<APFloat>().convertToDouble() == 1.0;
5377
+
} else if (attr.getElementType().isa<IntegerType>()) {
5378
+
return attr.getSplatValue<APInt>().getSExtValue() == 1;
5379
+
} else {
5380
+
return false;
5340
5381
}
5382
+
}
5383
+
5384
+
OpFoldResult MulOp::fold(ArrayRef<Attribute> attrs) {
5341
5385
// Handle special case where one operand is 1: x * 1 => x
5342
5386
if (attrs[0] || attrs[1]) {
5343
-
SplatElementsAttr attr = attrs[0] ? attrs[0].dyn_cast<SplatElementsAttr>()
5344
-
: attrs[1].dyn_cast<SplatElementsAttr>();
5345
-
if (!attr) return {};
5346
-
Value result = attrs[0] ? rhs() : lhs();
5347
-
if (attr.getElementType().isa<FloatType>()) {
5348
-
if (attr.getSplatValue<APFloat>().convertToDouble() == 1.0) return result;
5349
-
} else if (attr.getElementType().isa<IntegerType>()) {
5350
-
if (attr.getSplatValue<APInt>().getSExtValue() == 1) return result;
5351
-
}
5387
+
SplatElementsAttr splat_lhs =
5388
+
attrs[0].dyn_cast_or_null<SplatElementsAttr>();
5389
+
SplatElementsAttr splat_rhs =
5390
+
attrs[1].dyn_cast_or_null<SplatElementsAttr>();
5391
+
if (isSplatOne(splat_lhs))
5392
+
return splat_rhs ? (OpFoldResult)splat_rhs : rhs();
5393
+
if (isSplatOne(splat_rhs))
5394
+
return splat_lhs ? (OpFoldResult)splat_lhs : lhs();
5395
+
}
5396
+
if (attrs[0] && attrs[1]) {
5397
+
BINARY_FOLDER_INTERNAL(MulOp, std::multiplies);
5352
5398
}
5353
5399
return {};
5354
5400
}
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4