# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.integrate import odeint

# ===========================================================
# 常微分方程式を解くクラス
# ===========================================================
class ODE(object):

    # -------------------------------------------------------
    # コンストラクター
    # -------------------------------------------------------
    def __init__(self, diff_eq, init_con):

        self.diff_eq  = diff_eq
        self.init_con = init_con

        
    # -------------------------------------------------------
    # 常微分方程式の計算
    # -------------------------------------------------------
    def cal_euation(self, t_min, t_max, N):

        t = np.linspace(t_min, t_max, N)             # x の配列の生成
        v = odeint(self.diff_eq, self.init_con, t)   # 方程式の計算

        return v



# -------------------------------------------------------
# 解くべき常微分方程式
# -------------------------------------------------------
def diff_eq(v, t):
    p = 10
    r = 28
    b = 8/3
    dxdt = -p*v[0] + p*v[1]
    dydt = -v[0]*v[2] + r*v[0] -v[1]
    dzdt = v[0]*v[1] -b*v[2]
    return [dxdt, dydt, dzdt]


# -------------------------------------------------------
# プロット
# ------------------------------------------------------- 
def plot(x, y, z):
    fig = plt.figure()

    # ----- プロットの準備 -----
    sol = fig.add_subplot(1,1,1, projection='3d')
    sol.set_xlabel("x", fontsize=20, fontname='serif')
    sol.set_ylabel("y", fontsize=20, fontname='serif')
    sol.set_zlabel("z", fontsize=20, fontname='serif')
    sol.tick_params(axis='both', length=10,  which='major')
    sol.tick_params(axis='both', length=5,  which='minor')
    sol.minorticks_on()

    sol.plot(x, y, z, 'b-', markersize=0)

    # ----- スクリーン表示 -----
    fig.tight_layout()
    plt.show()
        
    # ----- pdf 作成 -----
    fig.savefig('Lorenz.pdf', orientation='portrait', \
                transparent=False, bbox_inches=None, frameon=None)
    fig.clf()



# -------------------------------------------------------
# メイン関数
# -------------------------------------------------------
if __name__ == "__main__":

    N = 10001                                # 分割数
    min_t = 0                                # t の最小
    max_t = 100                              # t の最大
    initial_condition = np.array([0.1, 0.1, 0.1])  # 初期条件

    ode = ODE(diff_eq, initial_condition)
    v = ode.cal_euation(min_t, max_t, N)

    plot(v[:, 0], v[:,1], v[:,2])
    
