// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace at { namespace functorch {
// Flattens out all dims except the batch dim, and also moves batch dim
// (if it exists) to front.
// Collapses every non-batch dim of `tensor` into a single dim.
//
// With a batch dim, the batch dim is first moved to the front and kept separate.
// The result is [B, prod(remaining dims)], or just [B] if there are no other dims.
// Without a batch dim, the whole tensor is flattened to 1-D.
at::Tensor flatten_logical(const Tensor& tensor, optional<int64_t> bdim) {
  if (!bdim.has_value()) {
    return tensor.flatten();
  }
  auto batch_first = moveBatchDimToFront(tensor, bdim);
  if (batch_first.dim() <= 1) {
    // Only the batch dim exists; there is nothing to flatten.
    return batch_first;
  }
  return batch_first.flatten(1);
}

// Batch rule for mse_loss.
//
// Both inputs are flattened with flatten_logical, giving [B, K] for a batched
// input or [K] for an unbatched one. mse_loss is then evaluated unreduced.
// The requested reduction is applied over the flattened trailing dim, so each
// batch entry is reduced independently. The output batch dim is always 0.
std::tuple<at::Tensor,optional<int64_t>>
mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
          optional<int64_t> target_bdim, int64_t reduction) {
  auto self_ = flatten_logical(self, self_bdim);
  auto target_ = flatten_logical(target, target_bdim);
  auto result = at::mse_loss(self_, target_, Reduction::None);

  if (result.dim() == 1) {
    // Only the batch dim is left (logical scalars): every reduction is a no-op.
    return std::make_tuple(result, 0);
  }
  if (reduction == Reduction::None) {
    // Restore the original (batch-first) shape of whichever input is batched.
    const auto batched_elem = self_bdim.has_value() ?
        moveBatchDimToFront(self, self_bdim) : moveBatchDimToFront(target, target_bdim);
    return std::make_tuple(result.reshape(batched_elem.sizes()), 0);
  }
  if (reduction == Reduction::Sum) {
    return std::make_tuple(result.sum(-1), 0);
  }
  if (reduction == Reduction::Mean) {
    return std::make_tuple(result.mean(-1), 0);
  }
  TORCH_INTERNAL_ASSERT(false, "mse_loss_batch_rule: unexpected reduction ", reduction);
}

// Applies a loss reduction to an elementwise (unreduced) loss tensor.
// Mean and Sum reduce over all elements. Any other value, including
// Reduction::None, returns the input unchanged.
static Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
  switch (reduction) {
    case at::Reduction::Mean:
      return unreduced.mean();
    case at::Reduction::Sum:
      return unreduced.sum();
    default:
      return unreduced;
  }
}

// Vmap plumbing for binary_cross_entropy.
//
// The elementwise loss (Reduction::None, no weight) is computed below the
// FuncTorchBatched key. If self or target has a batch dim at this level, both
// are given a leading batch dim and the result is re-wrapped as a batched
// tensor at dim 0.
//
// The optional weight and the reduction are applied afterwards, outside the
// guard, so a batched weight is handled by ordinary vmap dispatch.
Tensor binary_cross_entropy_plumbing(
    const Tensor& self, const Tensor& target,
    const optional<Tensor>& weight, int64_t reduction) {
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  const int64_t cur_level = maybe_layer->layerId();

  const bool any_batched = isBatchedAtLevel(self, cur_level) ||
      isBatchedAtLevel(target, cur_level) || isBatchedAtLevel(weight, cur_level);
  if (!any_batched) {
    // Nothing is batched at this level: run the regular kernel directly.
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::binary_cross_entropy(self, target, weight, reduction);
  }

  Tensor self_value;
  optional<int64_t> self_bdim;
  std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
  Tensor target_value;
  optional<int64_t> target_bdim;
  std::tie(target_value, target_bdim) = unwrapTensorAtLevel(target, cur_level);

  Tensor unreduced;
  {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    if (!self_bdim && !target_bdim) {
      // Only the weight is batched, so the elementwise loss itself is unbatched.
      unreduced = at::binary_cross_entropy(self_value, target_value, nullopt, Reduction::None);
    } else {
      const auto bdim_size = get_bdim_size2(self_value, self_bdim, target_value, target_bdim);
      const auto self_b = ensure_has_bdim(
          moveBatchDimToFront(self_value, self_bdim), self_bdim.has_value(), bdim_size);
      const auto target_b = ensure_has_bdim(
          moveBatchDimToFront(target_value, target_bdim), target_bdim.has_value(), bdim_size);
      unreduced = makeBatched(
          at::binary_cross_entropy(self_b, target_b, nullopt, Reduction::None), 0, cur_level);
    }
  }

  // Outside the guard, so a batched weight / result is dispatched through vmap.
  if (weight.has_value() && weight->defined()) {
    unreduced = unreduced * weight.value();
  }
  return apply_loss_reduction(unreduced, reduction);
}

// Vmap plumbing for binary_cross_entropy_backward.
//
// For Mean/Sum, grad is expanded to input's shape, at the vmap level. The
// elementwise backward (Reduction::None, no weight) is then computed below
// the FuncTorchBatched key.
//
// The weight multiply and the Mean division by input.numel() are applied
// afterwards, outside the guard, mirroring binary_cross_entropy_plumbing.
Tensor binary_cross_entropy_backward_plumbing(
    const Tensor& grad, const Tensor& input, const Tensor& target,
    const c10::optional<Tensor>& weight_opt, int64_t reduction) {
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  const int64_t cur_level = maybe_layer->layerId();

  if (!areAnyBatchedAtLevel({grad, input, target, weight_opt}, cur_level)) {
    // Nothing is batched at this level: run the regular kernel directly.
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::binary_cross_entropy_backward(grad, input, target, weight_opt, reduction);
  }

  // A reduced loss has a scalar grad; broadcast it so the unreduced backward applies.
  const Tensor grad_full = reduction == Reduction::None ? grad : grad.expand_as(input);

  Tensor grad_value;
  optional<int64_t> grad_bdim;
  std::tie(grad_value, grad_bdim) = unwrapTensorAtLevel(grad_full, cur_level);
  Tensor input_value;
  optional<int64_t> input_bdim;
  std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level);
  Tensor target_value;
  optional<int64_t> target_bdim;
  std::tie(target_value, target_bdim) = unwrapTensorAtLevel(target, cur_level);

  Tensor grad_input;
  {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    if (!grad_bdim && !input_bdim && !target_bdim) {
      // Only the weight is batched; the elementwise backward is unbatched.
      grad_input = at::binary_cross_entropy_backward(
          grad_value, input_value, target_value, nullopt, Reduction::None);
    } else {
      const auto bdim_size = get_bdim_size3(
          grad_value, grad_bdim, input_value, input_bdim, target_value, target_bdim);
      const auto grad_b = ensure_has_bdim(
          moveBatchDimToFront(grad_value, grad_bdim), grad_bdim.has_value(), bdim_size);
      const auto input_b = ensure_has_bdim(
          moveBatchDimToFront(input_value, input_bdim), input_bdim.has_value(), bdim_size);
      const auto target_b = ensure_has_bdim(
          moveBatchDimToFront(target_value, target_bdim), target_bdim.has_value(), bdim_size);
      grad_input = makeBatched(
          at::binary_cross_entropy_backward(grad_b, input_b, target_b, nullopt, Reduction::None),
          0, cur_level);
    }
  }

  // Outside the guard, so batched operands are dispatched through vmap.
  if (weight_opt.has_value() && weight_opt->defined()) {
    grad_input = grad_input * weight_opt.value();
  }
  if (reduction == Reduction::Mean) {
    grad_input.div_(input.numel());
  }
  return grad_input;
}

// Decomposition of nll_loss_forward, also registered for nll_loss2d_forward.
// It is expressed via gather, elementwise masking and sum/mean, so that vmap
// can run it through the batch rules of those ops.
//
// self:   [N, C, ...] or [C]. The class ("channel") dim is 1 in the first
//         case and 0 in the second.
// target: [N, ...] or [], holding class indices into the channel dim.
// weight: optional per-class weights. Only weight->sizes()[0] is read, and
//         the tensor is reshaped so that size sits on the channel dim.
//         NOTE(review): this assumes weight is 1-D with C entries; confirm.
// reduction: Reduction::None, Mean or Sum.
// ignore_index: targets equal to it contribute 0 to the loss and are excluded
//         from total_weight. Masking is only done when ignore_index >= 0. A
//         negative ignore_index disables masking entirely, so a target with
//         that value would still be passed to at::gather.
//         NOTE(review): confirm callers never hit that case.
//
// Returns (loss, total_weight). total_weight is one of:
//   - the number of loss elements;
//   - the number of non-ignored targets, when masking is on;
//   - the (masked) sum of the class weights selected by the targets, when a
//     weight is given with Reduction::Mean.
// nll_loss_backward_decomposition divides grad_output by it for Reduction::Mean.
std::tuple<Tensor, Tensor> nll_loss_forward_decomposition(
    const Tensor & self,
    const Tensor & target,
    const c10::optional<Tensor> & weight,
    int64_t reduction, int64_t ignore_index) {

  // self can be [N, C, ...] or [C]
  // target can be [N, ...] or []

  int64_t channel_dim = 1;
  if (self.dim() < 2) {
    channel_dim = 0;
  }
  // self_ becomes the weighted input when a weight is applied below.
  auto self_ = self;
  Tensor weight_;

  if (weight && weight->defined()) {
    // Special case: reduction Mean with a non-batched [C] input.
    // https://github.com/pytorch/pytorch/issues/61309
    // Here the weight cancels out, w * x[t] / w -> x[t], so it is not applied.
    if (!(reduction == Reduction::Mean && self_.dim() < 2)) {
      // Reshape weights to [1, C, 1, ..., 1] (or [C] for a 1-D self) so they
      // broadcast along the channel dim.
      auto shape = weight->sizes();
      VmapDimVector new_shape(self_.dim(), 1);
      new_shape[channel_dim] = shape[0];
      weight_ = weight->reshape(new_shape);
      self_ = self_ * weight_;
    }
  }
  auto target_ = target.unsqueeze(channel_dim);
  // target_ is now [N, 1, ...] or [1]

  // Per-element loss: -(weighted) self at each target's class index.
  auto result = -at::gather(self_, channel_dim, target_).squeeze(channel_dim);
  // Default normalizer: the number of loss elements, as a 0-d tensor with
  // self_'s dtype, layout and device.
  auto total_weight = at::full(
      {}, result.numel(), self_.scalar_type(),
      self_.layout(), self_.device(), nullopt);

  bool has_ignore_index = ignore_index >= 0;
  Tensor ignore_index_mask;
  if (has_ignore_index) {
    // Zero the loss at ignored targets and count the remaining ones.
    // .to(self_) converts the count to self_'s dtype and device.
    ignore_index_mask = target != ignore_index;
    result = result * ignore_index_mask;
    total_weight = ignore_index_mask.sum().to(self_);
  }

  // Apply the reduction. result.dim() == 0 only for a [C] input with [] target.
  if (result.dim() > 0) {
    if (reduction == Reduction::Sum) {
      result = result.sum();
    } else if (reduction == Reduction::Mean) {
      if (!weight || !weight->defined()) {
        if (has_ignore_index) {
          TORCH_INTERNAL_ASSERT(ignore_index_mask.defined());
          // total_weight is ignore_index_mask.sum()
          result = result.sum() / total_weight;
        } else {
          result = result.mean();
        }
      } else {
        // Weighted mean: divide by the sum of the class weights selected by
        // each (non-ignored) target.
        TORCH_INTERNAL_ASSERT(weight_.defined());
        weight_ = weight_.expand(self_.sizes());
        auto wsum = at::gather(weight_, channel_dim, target_).squeeze(channel_dim);
        if (has_ignore_index) {
          TORCH_INTERNAL_ASSERT(ignore_index_mask.defined());
          wsum = wsum * ignore_index_mask;
        }
        wsum = wsum.sum();
        result = result.sum() / wsum;
        total_weight = wsum;
      }
    }
  } else if (reduction == Reduction::Mean && weight && weight->defined()) {
    // Here weight is [C] and target_ is [1]. The loss was left unweighted
    // (the weight cancels); only total_weight records the selected weight.
    auto wsum = at::gather(*weight, channel_dim, target_).squeeze(channel_dim);
    if (has_ignore_index) {
      TORCH_INTERNAL_ASSERT(ignore_index_mask.defined());
      wsum = wsum * ignore_index_mask;
    }
    total_weight = wsum.sum();
  }

  return std::make_tuple(result, total_weight);
}

// Decomposition of nll_loss_backward, also registered for nll_loss2d_backward.
//
// The gradient w.r.t. self is -1 at each target's class index and 0 elsewhere.
// It is multiplied by grad_output, which is:
//   - first divided by total_weight for Reduction::Mean,
//   - then scaled by the per-class weight (if given),
//   - then zeroed where target == ignore_index (only when ignore_index >= 0).
at::Tensor nll_loss_backward_decomposition(
    const at::Tensor & grad_output, const at::Tensor & self,
    const at::Tensor & target, const c10::optional<at::Tensor> & weight,
    int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) {
  // Class dim: 1 for an [N, C, ...] input, 0 for a [C] input.
  const int64_t channel_dim = self.dim() < 2 ? 0 : 1;
  const auto target_ = target.unsqueeze(channel_dim);

  auto scaled_grad = reduction == Reduction::Mean ? grad_output / total_weight : grad_output;

  // One-hot (with -1) at the target class along the channel dim.
  auto grad_input = at::scatter(at::zeros_like(self), channel_dim, target_, -1.0);

  // An unreduced, non-scalar grad lacks the channel dim; add it for broadcasting.
  if (scaled_grad.dim() > 0 && scaled_grad.dim() < grad_input.dim()) {
    scaled_grad = scaled_grad.unsqueeze(channel_dim);
  }

  if (weight && weight->defined()) {
    // Place the weight's single dim on the channel dim: [1, C, 1, ..., 1].
    VmapDimVector weight_shape(self.dim(), 1);
    weight_shape[channel_dim] = weight->sizes()[0];
    scaled_grad = scaled_grad * weight->reshape(weight_shape);
  }

  if (ignore_index >= 0) {
    scaled_grad = scaled_grad * (target_ != ignore_index);
  }

  return grad_input * scaled_grad;
}

// Registers the loss batch rules, plumbing and decompositions above
// for the FuncTorchBatched dispatch key.
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  // mse_loss has a batch rule; mse_loss_backwards uses a decomposition for its
  // batch rule, so it is not registered here.
  VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);

  // binary_cross_entropy: hand-written plumbing for forward and backward.
  m.impl("binary_cross_entropy", binary_cross_entropy_plumbing);
  m.impl("binary_cross_entropy_backward", binary_cross_entropy_backward_plumbing);

  // nll_loss and nll_loss2d share the same forward/backward decompositions.
  m.impl("nll_loss_forward", nll_loss_forward_decomposition);
  m.impl("nll_loss2d_forward", nll_loss_forward_decomposition);
  m.impl("nll_loss_backward", nll_loss_backward_decomposition);
  m.impl("nll_loss2d_backward", nll_loss_backward_decomposition);
}

}}
